Posted to commits@pinot.apache.org by ja...@apache.org on 2020/05/12 21:02:49 UTC

[incubator-pinot] branch master updated: Clean up AggregationFunctionContext and use TransformExpressionTree as the key in the blockValSetMap passed to the AggregationFunctions (#5364)

This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 8b0089f  Clean up AggregationFunctionContext and use TransformExpressionTree as the key in the blockValSetMap passed to the AggregationFunctions (#5364)
8b0089f is described below

commit 8b0089f4e8f8d323abbdfb0d8dfd0c79e49b41c2
Author: Xiaotian (Jackie) Jiang <17...@users.noreply.github.com>
AuthorDate: Tue May 12 14:02:38 2020 -0700

    Clean up AggregationFunctionContext and use TransformExpressionTree as the key in the blockValSetMap passed to the AggregationFunctions (#5364)
    
    - Clean up all the usage of AggregationFunctionContext to directly use AggregationFunction
    - Construct the AggregationFunctions and Group-by Expressions at planning phase and pass them to Operator and Executor to save the extra expression compilation
    - Use TransformExpressionTree as the key in the blockValSetMap passed to the AggregationFunctions
      - This avoids redundant string conversions and gives more efficient hashCode() and equals()
      - The keys of the blockValSetMap should be the same as AggregationFunction.getInputExpressions()
      - The only exception is CountAggregationFunction with Star-Tree where there is a single entry in blockValSetMap (column "*")
    - Add a base implementation of AggregationFunction, BaseSingleInputAggregationFunction, for aggregation functions on a single expression
    - For the PERCENTILE family of aggregation functions, support passing the percentile as the second argument (e.g. PERCENTILE(column, 99), PERCENTILETDIGEST(column, 90))
    - Enhance Star-Tree Aggregation/Group-by Executor to handle the column name conversion so that AggregationFunctionColumnPair is transparent to the AggregationFunction
    
    BACKWARD-INCOMPATIBLE CHANGE:
    The following APIs are changed in AggregationFunction (use TransformExpressionTree instead of String as the key of blockValSetMap):
    void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<TransformExpressionTree, BlockValSet> blockValSetMap);
    void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, Map<TransformExpressionTree, BlockValSet> blockValSetMap);
    void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, Map<TransformExpressionTree, BlockValSet> blockValSetMap);
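The following minimal Java sketch shows why a structured key can replace the String key. Expr and Sum are hypothetical stand-ins, not Pinot's TransformExpressionTree or SumAggregationFunction, and double[] stands in for BlockValSet. The sketch illustrates the contract described above in two ways. First, a function reads blockValSetMap using exactly the expressions it returns from getInputExpressions(). Second, two structurally equal trees find the same map entry, because equality is defined on the tree itself rather than on its string form.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-ins for illustration only; these are NOT Pinot's real classes.
final class ExprKeyDemo {

  // Immutable expression node, standing in for TransformExpressionTree.
  static final class Expr {
    final String value;        // column name, function name or literal text
    final List<Expr> children; // empty list for leaf nodes
    private final int hash;    // computed once, since the node never changes

    private Expr(String value, List<Expr> children) {
      this.value = value;
      this.children = children;
      this.hash = 31 * value.hashCode() + children.hashCode();
    }

    static Expr column(String name) {
      return new Expr(name, Collections.<Expr>emptyList());
    }

    static Expr function(String name, Expr... args) {
      return new Expr(name, Collections.unmodifiableList(Arrays.asList(args)));
    }

    // Structural equality: same value and equal children, compared recursively.
    @Override
    public boolean equals(Object o) {
      if (this == o) {
        return true;
      }
      if (!(o instanceof Expr)) {
        return false;
      }
      Expr that = (Expr) o;
      return hash == that.hash && value.equals(that.value) && children.equals(that.children);
    }

    @Override
    public int hashCode() {
      return hash;
    }
  }

  // Sketch of a SUM function; double[] stands in for BlockValSet.
  static final class Sum {
    private final Expr input;
    double total;

    Sum(Expr input) {
      this.input = input;
    }

    List<Expr> getInputExpressions() {
      return Collections.singletonList(input);
    }

    // Looks up its values with the same key it reports from getInputExpressions().
    void aggregate(int length, Map<Expr, double[]> blockValSetMap) {
      double[] values = blockValSetMap.get(input);
      for (int i = 0; i < length; i++) {
        total += values[i];
      }
    }
  }

  public static void main(String[] args) {
    Sum sum = new Sum(Expr.function("add", Expr.column("a"), Expr.column("b")));
    Map<Expr, double[]> blockValSetMap = new HashMap<>();
    // The key is built independently but is structurally equal, so the lookup succeeds.
    blockValSetMap.put(Expr.function("add", Expr.column("a"), Expr.column("b")), new double[] {1, 2, 3});
    sum.aggregate(3, blockValSetMap);
    System.out.println(sum.total); // 6.0
  }
}
```

Caching the hash at construction is a choice made for this sketch only. Whether the real TransformExpressionTree does the same is not visible in this commit.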
---
 .../requesthandler/BaseBrokerRequestHandler.java   |  15 +-
 .../common/function/AggregationFunctionType.java   |  12 +-
 .../request/transform/TransformExpressionTree.java |  23 +-
 .../parsers/PinotQuery2BrokerRequestConverter.java |   2 +-
 .../core/common/datatable/DataTableUtils.java      |  23 +-
 .../apache/pinot/core/data/table/BaseTable.java    |  31 +--
 .../core/data/table/ConcurrentIndexedTable.java    |  10 +-
 .../apache/pinot/core/data/table/IndexedTable.java |   7 +-
 .../pinot/core/data/table/SimpleIndexedTable.java  |  10 +-
 .../apache/pinot/core/data/table/TableResizer.java |  40 ++-
 .../core/operator/CombineGroupByOperator.java      |  13 +-
 .../operator/CombineGroupByOrderByOperator.java    |  10 +-
 .../operator/blocks/IntermediateResultsBlock.java  |  48 ++--
 .../operator/query/AggregationGroupByOperator.java |  31 ++-
 .../query/AggregationGroupByOrderByOperator.java   |  64 +++--
 .../core/operator/query/AggregationOperator.java   |  14 +-
 .../query/DictionaryBasedAggregationOperator.java  |  49 ++--
 .../query/MetadataBasedAggregationOperator.java    |  32 +--
 .../plan/AggregationGroupByOrderByPlanNode.java    |  73 +++--
 .../core/plan/AggregationGroupByPlanNode.java      |  62 +++--
 .../pinot/core/plan/AggregationPlanNode.java       |  49 ++--
 .../plan/DictionaryBasedAggregationPlanNode.java   |  23 +-
 .../plan/MetadataBasedAggregationPlanNode.java     |  33 +--
 .../core/plan/maker/InstancePlanMakerImplV2.java   |   7 +-
 .../aggregation/AggregationFunctionContext.java    |  85 ------
 .../aggregation/DefaultAggregationExecutor.java    |  66 ++---
 .../core/query/aggregation/DistinctTable.java      |   8 +-
 .../aggregation/function/AggregationFunction.java  |   7 +-
 .../function/AggregationFunctionFactory.java       | 102 ++++---
 .../function/AggregationFunctionUtils.java         | 191 ++++++-------
 .../function/AvgAggregationFunction.java           |  44 +--
 .../function/AvgMVAggregationFunction.java         |  18 +-
 .../BaseSingleInputAggregationFunction.java        |  57 ++++
 .../function/CountAggregationFunction.java         |  32 +--
 .../function/CountMVAggregationFunction.java       |  29 +-
 .../function/DistinctAggregationFunction.java      |  39 ++-
 .../function/DistinctCountAggregationFunction.java |  42 +--
 .../DistinctCountHLLAggregationFunction.java       |  42 +--
 .../DistinctCountHLLMVAggregationFunction.java     |  18 +-
 .../DistinctCountMVAggregationFunction.java        |  18 +-
 .../DistinctCountRawHLLAggregationFunction.java    |  41 +--
 .../DistinctCountRawHLLMVAggregationFunction.java  |   6 +-
 ...istinctCountThetaSketchAggregationFunction.java |  53 ++--
 .../function/FastHLLAggregationFunction.java       |  42 +--
 .../function/MaxAggregationFunction.java           |  42 +--
 .../function/MaxMVAggregationFunction.java         |  18 +-
 .../function/MinAggregationFunction.java           |  42 +--
 .../function/MinMVAggregationFunction.java         |  18 +-
 .../function/MinMaxRangeAggregationFunction.java   |  41 +--
 .../function/MinMaxRangeMVAggregationFunction.java |  18 +-
 .../function/PercentileAggregationFunction.java    |  48 +---
 .../function/PercentileEstAggregationFunction.java |  43 +--
 .../PercentileEstMVAggregationFunction.java        |  28 +-
 .../function/PercentileMVAggregationFunction.java  |  28 +-
 .../PercentileTDigestAggregationFunction.java      |  43 +--
 .../PercentileTDigestMVAggregationFunction.java    |  28 +-
 .../function/SumAggregationFunction.java           |  41 +--
 .../function/SumMVAggregationFunction.java         |  18 +-
 .../groupby/DefaultGroupByExecutor.java            | 102 +++----
 .../query/aggregation/groupby/GroupByExecutor.java |   3 +-
 .../query/reduce/AggregationDataTableReducer.java  |  47 ++--
 .../pinot/core/query/reduce/CombineService.java    |  24 +-
 .../core/query/reduce/ComparisonFunction.java      |   5 +-
 .../query/reduce/DistinctDataTableReducer.java     |  13 +-
 .../core/query/reduce/GroupByDataTableReducer.java |  16 +-
 .../core/query/request/ServerQueryRequest.java     |   2 +-
 .../apache/pinot/core/startree/StarTreeUtils.java  |  32 +--
 .../executor/StarTreeAggregationExecutor.java      |  55 +---
 .../startree/executor/StarTreeGroupByExecutor.java |  70 ++---
 .../startree/plan/StarTreeTransformPlanNode.java   |  10 +-
 .../pinot/core/data/table/IndexedTableTest.java    | 128 ++++-----
 .../pinot/core/data/table/TableResizerTest.java    | 305 +++++++++------------
 .../function/AggregationFunctionFactoryTest.java   |   9 +-
 .../pinot/core/startree/v2/BaseStarTreeV2Test.java |  34 +--
 ...terSegmentAggregationMultiValueQueriesTest.java | 122 +++++----
 ...erSegmentAggregationSingleValueQueriesTest.java |  41 +--
 ...terSegmentResultTableMultiValueQueriesTest.java |  79 +++---
 ...erSegmentResultTableSingleValueQueriesTest.java |  78 +++---
 .../DefaultAggregationExecutorTest.java            |  33 +--
 .../apache/pinot/perf/BenchmarkCombineGroupBy.java |  76 ++---
 .../apache/pinot/perf/BenchmarkIndexedTable.java   |  61 ++---
 81 files changed, 1430 insertions(+), 1992 deletions(-)

diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java
index c9f5e29..6a00878 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java
@@ -456,13 +456,9 @@ public abstract class BaseBrokerRequestHandler implements BrokerRequestHandler {
       for (AggregationInfo info : brokerRequest.getAggregationsInfo()) {
         if (!info.getAggregationType().equalsIgnoreCase(AggregationFunctionType.COUNT.getName())) {
           // Always read from backward compatible api in AggregationFunctionUtils.
-          List<String> expressions = AggregationFunctionUtils.getAggregationExpressions(info);
-
-          List<String> newExpressions = new ArrayList<>(expressions.size());
-          for (String expression : expressions) {
-            newExpressions.add(fixColumnNameCase(actualTableName, expression));
-          }
-          info.setExpressions(newExpressions);
+          List<String> arguments = AggregationFunctionUtils.getArguments(info);
+          arguments.replaceAll(e -> fixColumnNameCase(actualTableName, e));
+          info.setExpressions(arguments);
         }
       }
       if (brokerRequest.isSetGroupBy()) {
@@ -720,11 +716,10 @@ public abstract class BaseBrokerRequestHandler implements BrokerRequestHandler {
             throw new UnsupportedOperationException("DISTINCT with GROUP BY is currently not supported");
           }
           if (brokerRequest.isSetOrderBy()) {
-            List<String> columns = AggregationFunctionUtils.getAggregationExpressions(aggregationInfo);
-            Set<String> set = new HashSet<>(columns);
+            Set<String> expressionSet = new HashSet<>(AggregationFunctionUtils.getArguments(aggregationInfo));
             List<SelectionSort> orderByColumns = brokerRequest.getOrderBy();
             for (SelectionSort selectionSort : orderByColumns) {
-              if (!set.contains(selectionSort.getColumn())) {
+              if (!expressionSet.contains(selectionSort.getColumn())) {
                 throw new UnsupportedOperationException(
                     "ORDER By should be only on some/all of the columns passed as arguments to DISTINCT");
               }
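The first BaseBrokerRequestHandler hunk replaces a copy-into-a-new-list loop with List.replaceAll, which rewrites the list in place. A minimal Java sketch of the same pattern follows. Here fixColumnNameCase is a hypothetical stand-in that only handles plain column names through a lower-cased lookup map, unlike the broker's real method.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

// Hypothetical sketch; fixColumnNameCase here is NOT the broker's implementation.
final class ColumnCaseFixDemo {

  // Assumed schema lookup: lower-cased column name -> column name as declared in the schema.
  static final Map<String, String> CANONICAL_NAMES = new HashMap<>();

  static {
    CANONICAL_NAMES.put("playerid", "playerId");
    CANONICAL_NAMES.put("runs", "runs");
  }

  // Handles plain column names only; anything unknown is returned unchanged.
  static String fixColumnNameCase(String expression) {
    return CANONICAL_NAMES.getOrDefault(expression.toLowerCase(Locale.ROOT), expression);
  }

  // Rewrites the list in place; the list must allow its elements to be replaced.
  static List<String> normalize(List<String> arguments) {
    arguments.replaceAll(ColumnCaseFixDemo::fixColumnNameCase);
    return arguments;
  }

  public static void main(String[] args) {
    List<String> arguments = new ArrayList<>(Arrays.asList("PLAYERID", "Runs", "unknownCol"));
    System.out.println(normalize(arguments)); // [playerId, runs, unknownCol]
  }
}
```

The in-place style has one caveat: it only works if the list returned by AggregationFunctionUtils.getArguments() allows element replacement. An unmodifiable list would throw UnsupportedOperationException.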
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java b/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
index d3cef28..af31639 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
@@ -76,17 +76,17 @@ public enum AggregationFunctionType {
     String upperCaseFunctionName = functionName.toUpperCase();
     if (upperCaseFunctionName.startsWith("PERCENTILE")) {
       String remainingFunctionName = upperCaseFunctionName.substring(10);
-      if (remainingFunctionName.matches("\\d+")) {
+      if (remainingFunctionName.isEmpty() || remainingFunctionName.matches("\\d+")) {
         return PERCENTILE;
-      } else if (remainingFunctionName.matches("EST\\d+")) {
+      } else if (remainingFunctionName.equals("EST") || remainingFunctionName.matches("EST\\d+")) {
         return PERCENTILEEST;
-      } else if (remainingFunctionName.matches("TDIGEST\\d+")) {
+      } else if (remainingFunctionName.equals("TDIGEST") || remainingFunctionName.matches("TDIGEST\\d+")) {
         return PERCENTILETDIGEST;
-      } else if (remainingFunctionName.matches("\\d+MV")) {
+      } else if (remainingFunctionName.equals("MV") || remainingFunctionName.matches("\\d+MV")) {
         return PERCENTILEMV;
-      } else if (remainingFunctionName.matches("EST\\d+MV")) {
+      } else if (remainingFunctionName.equals("ESTMV") || remainingFunctionName.matches("EST\\d+MV")) {
         return PERCENTILEESTMV;
-      } else if (remainingFunctionName.matches("TDIGEST\\d+MV")) {
+      } else if (remainingFunctionName.equals("TDIGESTMV") || remainingFunctionName.matches("TDIGEST\\d+MV")) {
         return PERCENTILETDIGESTMV;
       } else {
         throw new IllegalArgumentException("Invalid aggregation function name: " + functionName);
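The AggregationFunctionType change accepts the bare names (PERCENTILE, PERCENTILEEST, PERCENTILETDIGESTMV and so on) alongside the suffixed ones. This appears to be what lets PERCENTILE(column, 99) resolve to a function type at all.

The Java sketch below restates those rules compactly. The regex \d* matches either an empty string or digits, so it is equivalent to each isEmpty() || matches("\\d+") pair in the diff. percentileOf is a hypothetical helper, not Pinot's parsing code. It assumes integer percentiles and reads the value from the name suffix when one is present, or else from the second argument.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

// Hypothetical sketch; NOT Pinot's AggregationFunctionType code.
final class PercentileNameDemo {

  enum Type {
    PERCENTILE, PERCENTILEEST, PERCENTILETDIGEST, PERCENTILEMV, PERCENTILEESTMV, PERCENTILETDIGESTMV
  }

  // Same rules as the diff: "\\d*" accepts both the bare form and a numeric suffix.
  // The patterns are mutually exclusive because matches() requires a full match.
  static Type typeOf(String functionName) {
    String name = functionName.toUpperCase(Locale.ROOT);
    if (!name.startsWith("PERCENTILE")) {
      throw new IllegalArgumentException("Not a percentile function: " + functionName);
    }
    String rest = name.substring("PERCENTILE".length());
    if (rest.matches("\\d*")) {
      return Type.PERCENTILE;
    } else if (rest.matches("EST\\d*")) {
      return Type.PERCENTILEEST;
    } else if (rest.matches("TDIGEST\\d*")) {
      return Type.PERCENTILETDIGEST;
    } else if (rest.matches("\\d*MV")) {
      return Type.PERCENTILEMV;
    } else if (rest.matches("EST\\d*MV")) {
      return Type.PERCENTILEESTMV;
    } else if (rest.matches("TDIGEST\\d*MV")) {
      return Type.PERCENTILETDIGESTMV;
    }
    throw new IllegalArgumentException("Invalid aggregation function name: " + functionName);
  }

  // Hypothetical helper, assuming integer percentiles: a numeric suffix in the
  // name is used if present; otherwise the second argument (e.g. PERCENTILE(column, 99)).
  static int percentileOf(String functionName, List<String> arguments) {
    String digits = functionName.replaceAll("\\D", "");
    if (!digits.isEmpty()) {
      return Integer.parseInt(digits);
    }
    if (arguments.size() < 2) {
      throw new IllegalArgumentException("No percentile given for " + functionName);
    }
    return Integer.parseInt(arguments.get(1).trim());
  }

  public static void main(String[] args) {
    // PERCENTILETDIGEST 90
    System.out.println(typeOf("percentiletdigest") + " " + percentileOf("percentiletdigest", Arrays.asList("column", "90")));
    // PERCENTILEMV 95
    System.out.println(typeOf("PERCENTILE95MV") + " " + percentileOf("PERCENTILE95MV", Arrays.asList("column")));
  }
}
```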
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/request/transform/TransformExpressionTree.java b/pinot-common/src/main/java/org/apache/pinot/common/request/transform/TransformExpressionTree.java
index 5c0a515..527df87 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/request/transform/TransformExpressionTree.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/request/transform/TransformExpressionTree.java
@@ -22,14 +22,13 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 import java.util.Set;
-import javax.annotation.Nonnull;
-import org.apache.pinot.spi.utils.EqualityUtils;
+import javax.annotation.Nullable;
 import org.apache.pinot.pql.parsers.Pql2Compiler;
 import org.apache.pinot.pql.parsers.pql2.ast.AstNode;
 import org.apache.pinot.pql.parsers.pql2.ast.FunctionCallAstNode;
 import org.apache.pinot.pql.parsers.pql2.ast.IdentifierAstNode;
 import org.apache.pinot.pql.parsers.pql2.ast.LiteralAstNode;
-import org.apache.pinot.pql.parsers.pql2.ast.StringLiteralAstNode;
+import org.apache.pinot.spi.utils.EqualityUtils;
 
 
 /**
@@ -37,7 +36,7 @@ import org.apache.pinot.pql.parsers.pql2.ast.StringLiteralAstNode;
  * <ul>
  *   <li>A TransformExpressionTree node has either transform function or a column name, or a literal.</li>
  *   <li>Leaf nodes either have column name or literal, whereas non-leaf nodes have transform function.</li>
- *   <li>Transform function in non-leaf nodes is applied to its children nodes.</li>
+ *   <li>Transform function is applied to its children.</li>
  * </ul>
  */
 public class TransformExpressionTree {
@@ -66,10 +65,9 @@ public class TransformExpressionTree {
     } else if (astNode instanceof FunctionCallAstNode) {
       // UDF expression
       return standardizeExpression(((FunctionCallAstNode) astNode).getExpression());
-    } else if (astNode instanceof StringLiteralAstNode) {
-      // Treat string as column name
-      // NOTE: this is for backward-compatibility
-      return ((StringLiteralAstNode) astNode).getText();
+    } else if (astNode instanceof LiteralAstNode) {
+      // Literal
+      return ((LiteralAstNode) astNode).getValueAsString();
     } else {
       throw new IllegalStateException("Cannot get standard expression from " + astNode.getClass().getSimpleName());
     }
@@ -106,6 +104,13 @@ public class TransformExpressionTree {
     }
   }
 
+  public TransformExpressionTree(ExpressionType expressionType, String value,
+      @Nullable List<TransformExpressionTree> children) {
+    _expressionType = expressionType;
+    _value = value;
+    _children = children;
+  }
+
   /**
    * Returns the expression type of the node, which can be one of the following:
    * <ul>
@@ -168,7 +173,7 @@ public class TransformExpressionTree {
    *
    * @param columns Output columns
    */
-  public void getColumns(@Nonnull Set<String> columns) {
+  public void getColumns(Set<String> columns) {
     if (_expressionType == ExpressionType.IDENTIFIER) {
       columns.add(_value);
     } else if (_children != null) {
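TransformExpressionTree changes in two ways:

- It gains a public constructor taking (expressionType, value, children), so trees can be assembled directly instead of only through the PQL compiler.
- getColumns() drops its @Nonnull annotation.

The Java sketch below mirrors both pieces with a hypothetical Node class and ExpressionType enum. The enum is assumed to have FUNCTION, IDENTIFIER and LITERAL values; only IDENTIFIER appears in the diff. The sketch shows how getColumns() collects only identifier leaves, skipping function names and literals.

```java
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical mirror of TransformExpressionTree; names and enum values are assumptions.
final class ExpressionTreeDemo {

  enum ExpressionType { FUNCTION, IDENTIFIER, LITERAL }

  static final class Node {
    final ExpressionType expressionType;
    final String value;
    final List<Node> children; // null for leaf nodes (identifiers and literals)

    // Same shape as the new public constructor in the diff.
    Node(ExpressionType expressionType, String value, List<Node> children) {
      this.expressionType = expressionType;
      this.value = value;
      this.children = children;
    }

    // Adds every identifier (column name) in the tree; function names and literals are skipped.
    void getColumns(Set<String> columns) {
      if (expressionType == ExpressionType.IDENTIFIER) {
        columns.add(value);
      } else if (children != null) {
        for (Node child : children) {
          child.getColumns(columns);
        }
      }
    }
  }

  static Node identifier(String name) {
    return new Node(ExpressionType.IDENTIFIER, name, null);
  }

  static Node literal(String text) {
    return new Node(ExpressionType.LITERAL, text, null);
  }

  static Node function(String name, Node... args) {
    return new Node(ExpressionType.FUNCTION, name, Arrays.asList(args));
  }

  public static void main(String[] args) {
    // add(a, mult(b, 2)), assembled directly rather than compiled from a PQL string.
    Node tree = function("add", identifier("a"), function("mult", identifier("b"), literal("2")));
    Set<String> columns = new LinkedHashSet<>();
    tree.getColumns(columns);
    System.out.println(columns); // [a, b]
  }
}
```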
diff --git a/pinot-common/src/main/java/org/apache/pinot/pql/parsers/PinotQuery2BrokerRequestConverter.java b/pinot-common/src/main/java/org/apache/pinot/pql/parsers/PinotQuery2BrokerRequestConverter.java
index 6ce2f16..1ce9a06 100644
--- a/pinot-common/src/main/java/org/apache/pinot/pql/parsers/PinotQuery2BrokerRequestConverter.java
+++ b/pinot-common/src/main/java/org/apache/pinot/pql/parsers/PinotQuery2BrokerRequestConverter.java
@@ -226,7 +226,7 @@ public class PinotQuery2BrokerRequestConverter {
   private String getColumnExpression(Expression functionParam) {
     switch (functionParam.getType()) {
       case LITERAL:
-        return functionParam.getLiteral().getStringValue();
+        return functionParam.getLiteral().getFieldValue().toString();
       case IDENTIFIER:
         return functionParam.getIdentifier().getName();
       case FUNCTION:
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/DataTableUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/DataTableUtils.java
index 2122779..1574a24 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/DataTableUtils.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/DataTableUtils.java
@@ -26,7 +26,6 @@ import org.apache.pinot.common.request.BrokerRequest;
 import org.apache.pinot.common.request.Selection;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.common.utils.DataTable;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 import org.apache.pinot.core.util.QueryOptions;
@@ -101,9 +100,8 @@ public class DataTableUtils {
     }
 
     // Aggregation query.
-    AggregationFunctionContext[] aggregationFunctionContexts =
-        AggregationFunctionUtils.getAggregationFunctionContexts(brokerRequest);
-    int numAggregations = aggregationFunctionContexts.length;
+    AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(brokerRequest);
+    int numAggregations = aggregationFunctions.length;
     if (brokerRequest.isSetGroupBy()) {
       // Aggregation group-by query.
 
@@ -121,9 +119,9 @@ public class DataTableUtils {
           columnDataTypes[index] = DataSchema.ColumnDataType.STRING;
           index++;
         }
-        for (AggregationFunctionContext aggregationFunctionContext : aggregationFunctionContexts) {
-          columnNames[index] = aggregationFunctionContext.getResultColumnName();
-          AggregationFunction aggregationFunction = aggregationFunctionContext.getAggregationFunction();
+        for (AggregationFunction aggregationFunction : aggregationFunctions) {
+          // NOTE: Use AggregationFunction.getResultColumnName() for SQL format response
+          columnNames[index] = aggregationFunction.getResultColumnName();
           columnDataTypes[index] = aggregationFunction.getIntermediateResultColumnType();
           index++;
         }
@@ -137,9 +135,10 @@ public class DataTableUtils {
 
         // Build the data table.
         DataTableBuilder dataTableBuilder = new DataTableBuilder(new DataSchema(columnNames, columnDataTypes));
-        for (AggregationFunctionContext aggregationFunctionContext : aggregationFunctionContexts) {
+        for (AggregationFunction aggregationFunction : aggregationFunctions) {
           dataTableBuilder.startRow();
-          dataTableBuilder.setColumn(0, aggregationFunctionContext.getAggregationColumnName());
+          // NOTE: For backward-compatibility, use AggregationFunction.getColumnName() for PQL format response
+          dataTableBuilder.setColumn(0, aggregationFunction.getColumnName());
           dataTableBuilder.setColumn(1, Collections.emptyMap());
           dataTableBuilder.finishRow();
         }
@@ -152,9 +151,9 @@ public class DataTableUtils {
       DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[numAggregations];
       Object[] aggregationResults = new Object[numAggregations];
       for (int i = 0; i < numAggregations; i++) {
-        AggregationFunctionContext aggregationFunctionContext = aggregationFunctionContexts[i];
-        aggregationColumnNames[i] = aggregationFunctionContext.getAggregationColumnName();
-        AggregationFunction aggregationFunction = aggregationFunctionContext.getAggregationFunction();
+        AggregationFunction aggregationFunction = aggregationFunctions[i];
+        // NOTE: For backward-compatibility, use AggregationFunction.getColumnName() for aggregation only query
+        aggregationColumnNames[i] = aggregationFunction.getColumnName();
         columnDataTypes[i] = aggregationFunction.getIntermediateResultColumnType();
         aggregationResults[i] =
             aggregationFunction.extractAggregationResult(aggregationFunction.createAggregationResultHolder());
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/table/BaseTable.java b/pinot-core/src/main/java/org/apache/pinot/core/data/table/BaseTable.java
index 03d8b3b..3ad90ac 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/data/table/BaseTable.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/data/table/BaseTable.java
@@ -21,22 +21,20 @@ package org.apache.pinot.core.data.table;
 import java.util.Iterator;
 import java.util.List;
 import org.apache.commons.collections.CollectionUtils;
-import org.apache.pinot.common.request.AggregationInfo;
 import org.apache.pinot.common.request.SelectionSort;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 
 
 /**
  * Base abstract implementation of Table
  */
 public abstract class BaseTable implements Table {
-
-  final AggregationFunction[] _aggregationFunctions;
-  final int _numAggregations;
+  // TODO: After fixing the DistinctTable logic, make it final
   protected DataSchema _dataSchema;
-  final int _numColumns;
+  protected final int _numColumns;
+  protected final AggregationFunction[] _aggregationFunctions;
+  protected final int _numAggregations;
 
   // the capacity we need to trim to
   protected int _capacity;
@@ -46,34 +44,27 @@ public abstract class BaseTable implements Table {
   protected boolean _isOrderBy;
   protected TableResizer _tableResizer;
 
-  private final List<AggregationInfo> _aggregationInfos;
-
   /**
    * Initializes the variables and comparators needed for the table
    */
-  public BaseTable(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy, int capacity) {
+  public BaseTable(DataSchema dataSchema, AggregationFunction[] aggregationFunctions, List<SelectionSort> orderBy,
+      int capacity) {
     _dataSchema = dataSchema;
     _numColumns = dataSchema.size();
-
-    _numAggregations = aggregationInfos.size();
-    _aggregationFunctions = new AggregationFunction[_numAggregations];
-    for (int i = 0; i < _numAggregations; i++) {
-      _aggregationFunctions[i] =
-          AggregationFunctionUtils.getAggregationFunctionContext(aggregationInfos.get(i)).getAggregationFunction();
-    }
-
-    _aggregationInfos = aggregationInfos;
+    _aggregationFunctions = aggregationFunctions;
+    _numAggregations = aggregationFunctions.length;
     addCapacityAndOrderByInfo(orderBy, capacity);
   }
 
   protected void addCapacityAndOrderByInfo(List<SelectionSort> orderBy, int capacity) {
     _isOrderBy = CollectionUtils.isNotEmpty(orderBy);
     if (_isOrderBy) {
-      _tableResizer = new TableResizer(_dataSchema, _aggregationInfos, orderBy);
+      _tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, orderBy);
 
       // TODO: tune these numbers and come up with a better formula (github ISSUE-4801)
       // Based on the capacity and maxCapacity, the resizer will smartly choose to evict/retain recors from the PQ
-      if (capacity <= 100_000) { // Capacity is small, make a very large buffer. Make PQ of records to retain, during resize
+      if (capacity
+          <= 100_000) { // Capacity is small, make a very large buffer. Make PQ of records to retain, during resize
         _maxCapacity = 1_000_000;
       } else { // Capacity is large, make buffer only slightly bigger. Make PQ of records to evict, during resize
         _maxCapacity = (int) (capacity * 1.2);
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/table/ConcurrentIndexedTable.java b/pinot-core/src/main/java/org/apache/pinot/core/data/table/ConcurrentIndexedTable.java
index aa52218..3c25ef2 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/data/table/ConcurrentIndexedTable.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/data/table/ConcurrentIndexedTable.java
@@ -27,9 +27,9 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
-import org.apache.pinot.common.request.AggregationInfo;
 import org.apache.pinot.common.request.SelectionSort;
 import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -52,13 +52,13 @@ public class ConcurrentIndexedTable extends IndexedTable {
   /**
    * Initializes the data structures needed for this Table
    * @param dataSchema data schema of the record's keys and values
-   * @param aggregationInfos aggregation infos for the aggregations in record's values
+   * @param aggregationFunctions aggregation functions for the record's values
    * @param orderBy list of {@link SelectionSort} defining the order by
    * @param capacity the capacity of the table
    */
-  public ConcurrentIndexedTable(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy,
-      int capacity) {
-    super(dataSchema, aggregationInfos, orderBy, capacity);
+  public ConcurrentIndexedTable(DataSchema dataSchema, AggregationFunction[] aggregationFunctions,
+      List<SelectionSort> orderBy, int capacity) {
+    super(dataSchema, aggregationFunctions, orderBy, capacity);
 
     _lookupMap = new ConcurrentHashMap<>();
     _readWriteLock = new ReentrantReadWriteLock();
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/table/IndexedTable.java b/pinot-core/src/main/java/org/apache/pinot/core/data/table/IndexedTable.java
index 0f29f7a..cb2dc33 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/data/table/IndexedTable.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/data/table/IndexedTable.java
@@ -20,9 +20,9 @@ package org.apache.pinot.core.data.table;
 
 import java.util.Arrays;
 import java.util.List;
-import org.apache.pinot.common.request.AggregationInfo;
 import org.apache.pinot.common.request.SelectionSort;
 import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 
 
 /**
@@ -36,8 +36,9 @@ public abstract class IndexedTable extends BaseTable {
   /**
    * Initializes the variables and comparators needed for the table
    */
-  IndexedTable(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy, int capacity) {
-    super(dataSchema, aggregationInfos, orderBy, capacity);
+  IndexedTable(DataSchema dataSchema, AggregationFunction[] aggregationFunctions, List<SelectionSort> orderBy,
+      int capacity) {
+    super(dataSchema, aggregationFunctions, orderBy, capacity);
 
     _numKeyColumns = dataSchema.size() - _numAggregations;
     _keyExtractor = new KeyExtractor(_numKeyColumns);
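IndexedTable now derives the number of key columns as dataSchema.size() minus the number of aggregation functions. This relies on the convention that group-by key columns come before aggregation columns in the schema; the TableResizer comment in this commit states the same assumption.

The following minimal Java sketch shows what that split means for a single record. Key and extractKey are hypothetical stand-ins, and the real KeyExtractor's internals are not shown in this diff.

```java
import java.util.Arrays;

// Hypothetical sketch; Key and extractKey are NOT Pinot's Key/KeyExtractor.
final class KeyExtractionDemo {

  // Group-by key with value-based equals()/hashCode(), so it can be used as a map key.
  static final class Key {
    final Object[] values;

    Key(Object[] values) {
      this.values = values;
    }

    @Override
    public boolean equals(Object o) {
      return o instanceof Key && Arrays.equals(values, ((Key) o).values);
    }

    @Override
    public int hashCode() {
      return Arrays.hashCode(values);
    }
  }

  // Assumes key columns come first and the numAggregations aggregation columns come last.
  static Key extractKey(Object[] recordValues, int numColumns, int numAggregations) {
    int numKeyColumns = numColumns - numAggregations;
    return new Key(Arrays.copyOf(recordValues, numKeyColumns));
  }

  public static void main(String[] args) {
    // Schema: country, year | count, avg -> 4 columns, 2 aggregations.
    Object[] record = {"US", 2020, 42L, 3.5};
    Key key = extractKey(record, 4, 2);
    System.out.println(Arrays.toString(key.values)); // [US, 2020]
  }
}
```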
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/table/SimpleIndexedTable.java b/pinot-core/src/main/java/org/apache/pinot/core/data/table/SimpleIndexedTable.java
index cb28c08..ffdf103 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/data/table/SimpleIndexedTable.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/data/table/SimpleIndexedTable.java
@@ -24,9 +24,9 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import javax.annotation.concurrent.NotThreadSafe;
-import org.apache.pinot.common.request.AggregationInfo;
 import org.apache.pinot.common.request.SelectionSort;
 import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -48,13 +48,13 @@ public class SimpleIndexedTable extends IndexedTable {
   /**
    * Initializes the data structures needed for this Table
    * @param dataSchema data schema of the record's keys and values
-   * @param aggregationInfos aggregation infos for the aggregations in record'd values
+   * @param aggregationFunctions aggregation functions for the record's values
    * @param orderBy list of {@link SelectionSort} defining the order by
    * @param capacity the capacity of the table
    */
-  public SimpleIndexedTable(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy,
-      int capacity) {
-    super(dataSchema, aggregationInfos, orderBy, capacity);
+  public SimpleIndexedTable(DataSchema dataSchema, AggregationFunction[] aggregationFunctions,
+      List<SelectionSort> orderBy, int capacity) {
+    super(dataSchema, aggregationFunctions, orderBy, capacity);
 
     _lookupMap = new HashMap<>();
   }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
index 19b8b4c..1fa1deb 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
@@ -31,11 +31,9 @@ import java.util.Map;
 import java.util.PriorityQueue;
 import java.util.Set;
 import java.util.function.Function;
-import org.apache.pinot.common.request.AggregationInfo;
 import org.apache.pinot.common.request.SelectionSort;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 
 
 /**
@@ -48,13 +46,13 @@ public class TableResizer {
   private Comparator<Record> _recordComparator;
   protected int _numOrderBy;
 
-  TableResizer(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy) {
+  TableResizer(DataSchema dataSchema, AggregationFunction[] aggregationFunctions, List<SelectionSort> orderBy) {
 
     // NOTE: the assumption here is that the key columns will appear before the aggregation columns in the data schema
     // This is handled in the only in the AggregationGroupByOrderByOperator for now
 
     int numColumns = dataSchema.size();
-    int numAggregations = aggregationInfos.size();
+    int numAggregations = aggregationFunctions.length;
     int numKeyColumns = numColumns - numAggregations;
 
     Map<String, Integer> columnIndexMap = new HashMap<>();
@@ -63,10 +61,7 @@ public class TableResizer {
       String columnName = dataSchema.getColumnName(i);
       columnIndexMap.put(columnName, i);
       if (i >= numKeyColumns) {
-        AggregationInfo aggregationInfo = aggregationInfos.get(i - numKeyColumns);
-        AggregationFunction aggregationFunction =
-            AggregationFunctionUtils.getAggregationFunctionContext(aggregationInfo).getAggregationFunction();
-        aggregationColumnToFunction.put(columnName, aggregationFunction);
+        aggregationColumnToFunction.put(columnName, aggregationFunctions[i - numKeyColumns]);
       }
     }
 
@@ -109,7 +104,8 @@ public class TableResizer {
       };
     } else {
       // For cases where the entire Record is unique and is treated as a key
-      Preconditions.checkState(numKeyColumns == numColumns, "number of key columns should be equal to total number of columns");
+      Preconditions
+          .checkState(numKeyColumns == numColumns, "number of key columns should be equal to total number of columns");
       int[] orderByIndexes = new int[_numOrderBy];
       boolean[] orderByAsc = new boolean[_numOrderBy];
       for (int i = 0; i < _numOrderBy; i++) {
@@ -194,7 +190,7 @@ public class TableResizer {
     return priorityQueue;
   }
 
-   private List<Record> sortRecordsMap(Map<Key, Record> recordsMap) {
+  private List<Record> sortRecordsMap(Map<Key, Record> recordsMap) {
     int numRecords = recordsMap.size();
     List<Record> sortedRecords = new ArrayList<>(numRecords);
     List<IntermediateRecord> intermediateRecords = new ArrayList<>(numRecords);
@@ -321,7 +317,6 @@ public class TableResizer {
     }
   }
 
-
   /********************************************************
    *                                                      *
    * Resize functions for Set based table implementation  *
@@ -342,9 +337,10 @@ public class TableResizer {
       Object[] values1 = record1.getValues();
       Object[] values2 = record2.getValues();
       for (int i = 0; i < _numOrderBy; i++) {
-        Comparable valueToCompare1 = (Comparable)values1[_orderByColumnIndexes[i]];
-        Comparable valueToCompare2 = (Comparable)values2[_orderByColumnIndexes[i]];
-        int result = _orderByAsc[i] ? valueToCompare1.compareTo(valueToCompare2) : valueToCompare2.compareTo(valueToCompare1);
+        Comparable valueToCompare1 = (Comparable) values1[_orderByColumnIndexes[i]];
+        Comparable valueToCompare2 = (Comparable) values2[_orderByColumnIndexes[i]];
+        int result =
+            _orderByAsc[i] ? valueToCompare1.compareTo(valueToCompare2) : valueToCompare2.compareTo(valueToCompare1);
         if (result != 0) {
           return result;
         }
@@ -359,14 +355,16 @@ public class TableResizer {
       if (numRecordsToEvict < trimToSize) {
         // num records to evict is smaller than num records to retain
         // make PQ of records to evict
-        PriorityQueue<Record> priorityQueue = buildPriorityQueueFromRecordSet(numRecordsToEvict, recordSet, _recordComparator);
+        PriorityQueue<Record> priorityQueue =
+            buildPriorityQueueFromRecordSet(numRecordsToEvict, recordSet, _recordComparator);
         for (Record recordToEvict : priorityQueue) {
           recordSet.remove(recordToEvict);
         }
       } else {
         // num records to retain is smaller than num records to evict
         // make PQ of records to retain
-        PriorityQueue<Record> priorityQueue = buildPriorityQueueFromRecordSet(trimToSize, recordSet, _recordComparator.reversed());
+        PriorityQueue<Record> priorityQueue =
+            buildPriorityQueueFromRecordSet(trimToSize, recordSet, _recordComparator.reversed());
         ObjectOpenHashSet<Record> recordsToRetain = new ObjectOpenHashSet<>(priorityQueue.size());
         for (Record recordToRetain : priorityQueue) {
           recordsToRetain.add(recordToRetain);
@@ -376,9 +374,7 @@ public class TableResizer {
     }
   }
 
-  private PriorityQueue<Record> buildPriorityQueueFromRecordSet(
-      int size,
-      Set<Record> recordSet,
+  private PriorityQueue<Record> buildPriorityQueueFromRecordSet(int size, Set<Record> recordSet,
       Comparator<Record> comparator) {
     PriorityQueue<Record> priorityQueue = new PriorityQueue<>(size, comparator);
     for (Record record : recordSet) {
@@ -416,7 +412,8 @@ public class TableResizer {
       // num records to evict is smaller than num records to retain
       if (numRecordsToEvict > 0) {
         // make PQ of records to evict
-        PriorityQueue<Record> priorityQueue = buildPriorityQueueFromRecordSet(numRecordsToEvict, recordSet, _recordComparator);
+        PriorityQueue<Record> priorityQueue =
+            buildPriorityQueueFromRecordSet(numRecordsToEvict, recordSet, _recordComparator);
         for (Record recordToEvict : priorityQueue) {
           recordSet.remove(recordToEvict);
         }
@@ -424,7 +421,8 @@ public class TableResizer {
       return sortRecordSet(recordSet);
     } else {
       // make PQ of records to retain
-      PriorityQueue<Record> priorityQueue = buildPriorityQueueFromRecordSet(numRecordsToRetain, recordSet, _recordComparator.reversed());
+      PriorityQueue<Record> priorityQueue =
+          buildPriorityQueueFromRecordSet(numRecordsToRetain, recordSet, _recordComparator.reversed());
       // use PQ to get sorted list
       Record[] sortedArray = new Record[numRecordsToRetain];
       while (!priorityQueue.isEmpty()) {
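The Set-based resize path in TableResizer trims from whichever side is smaller. When fewer records need evicting, it builds a priority queue of the records to evict using `_recordComparator`. Otherwise it builds a queue of the records to retain using `_recordComparator.reversed()`. The sketch below shows that idea in a self-contained form and is not Pinot's implementation. `SetTrimSketch`, `boundedHeap` and `trim` are hypothetical names. Two details are assumptions: the bounded-heap behaviour (poll the head whenever the queue grows past `size`) and the comparator direction. This diff only partly shows the body of `buildPriorityQueueFromRecordSet`.

```java
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.TreeSet;

class SetTrimSketch {

  // Returns a heap holding the `size` items that rank greatest under `comparator`.
  // The head is always the least of the kept items, so it is the one dropped on overflow.
  static <T> PriorityQueue<T> boundedHeap(int size, Collection<T> items, Comparator<T> comparator) {
    // PriorityQueue rejects an initial capacity below 1, so clamp it
    PriorityQueue<T> heap = new PriorityQueue<>(Math.max(1, size), comparator);
    for (T item : items) {
      heap.offer(item);
      if (heap.size() > size) {
        heap.poll();
      }
    }
    return heap;
  }

  // Assumption: order.compare(a, b) < 0 means `a` should be kept in preference to `b`.
  static <T> void trim(Set<T> set, int trimToSize, Comparator<T> order) {
    int numToEvict = set.size() - trimToSize;
    if (numToEvict <= 0) {
      return;
    }
    if (numToEvict < trimToSize) {
      // Fewer records to evict: heap up the worst numToEvict and remove them
      for (T worst : boundedHeap(numToEvict, set, order)) {
        set.remove(worst);
      }
    } else {
      // Fewer records to retain: heap up the best trimToSize and keep only those
      set.retainAll(new HashSet<>(boundedHeap(trimToSize, set, order.reversed())));
    }
  }

  public static void main(String[] args) {
    Set<Integer> set = new HashSet<>();
    for (int i = 1; i <= 10; i++) {
      set.add(i);
    }
    trim(set, 3, Comparator.naturalOrder());
    System.out.println(new TreeSet<>(set)); // prints [1, 2, 3]
  }
}
```

Either branch builds a heap of at most min(numToEvict, trimToSize) entries. This is why the original code compares the two counts before choosing a comparator.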
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOperator.java
index 0508142..a7bb5c2 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOperator.java
@@ -37,7 +37,6 @@ import org.apache.pinot.common.request.BrokerRequest;
 import org.apache.pinot.common.response.ProcessingException;
 import org.apache.pinot.core.common.Operator;
 import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
@@ -108,13 +107,8 @@ public class CombineGroupByOperator extends BaseOperator<IntermediateResultsBloc
     AtomicInteger numGroups = new AtomicInteger();
     ConcurrentLinkedQueue<ProcessingException> mergedProcessingExceptions = new ConcurrentLinkedQueue<>();
 
-    AggregationFunctionContext[] aggregationFunctionContexts =
-        AggregationFunctionUtils.getAggregationFunctionContexts(_brokerRequest);
-    int numAggregationFunctions = aggregationFunctionContexts.length;
-    AggregationFunction[] aggregationFunctions = new AggregationFunction[numAggregationFunctions];
-    for (int i = 0; i < numAggregationFunctions; i++) {
-      aggregationFunctions[i] = aggregationFunctionContexts[i].getAggregationFunction();
-    }
+    AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(_brokerRequest);
+    int numAggregationFunctions = aggregationFunctions.length;
 
     // We use a CountDownLatch to track if all Futures are finished by the query timeout, and cancel the unfinished
     // futures (try to interrupt the execution if it already started).
@@ -205,8 +199,7 @@ public class CombineGroupByOperator extends BaseOperator<IntermediateResultsBloc
           new AggregationGroupByTrimmingService(aggregationFunctions, (int) _brokerRequest.getGroupBy().getTopN());
       List<Map<String, Object>> trimmedResults =
           aggregationGroupByTrimmingService.trimIntermediateResultsMap(resultsMap);
-      IntermediateResultsBlock mergedBlock =
-          new IntermediateResultsBlock(aggregationFunctionContexts, trimmedResults, true);
+      IntermediateResultsBlock mergedBlock = new IntermediateResultsBlock(aggregationFunctions, trimmedResults, true);
 
       // Set the processing exceptions.
       if (!mergedProcessingExceptions.isEmpty()) {
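The CombineGroupByOperator hunk above keeps a comment about using a CountDownLatch. The latch tracks whether all per-segment futures finished by the query timeout, and any unfinished futures are cancelled, with an interrupt if they already started. The sketch below shows that pattern using only `java.util.concurrent`. The class and method names are hypothetical. Error handling is reduced to a comment rather than Pinot's ProcessingException bookkeeping.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

class LatchTimeoutSketch {

  // Runs every task, waits up to timeoutMs for all of them, and cancels the unfinished ones.
  // Returns true only if every task finished before the timeout.
  static boolean runAll(List<Callable<Void>> tasks, long timeoutMs) throws InterruptedException {
    ExecutorService executor = Executors.newFixedThreadPool(Math.max(1, tasks.size()));
    CountDownLatch latch = new CountDownLatch(tasks.size());
    List<Future<?>> futures = new ArrayList<>();
    try {
      for (Callable<Void> task : tasks) {
        futures.add(executor.submit(() -> {
          try {
            task.call();
          } catch (Exception e) {
            // A real operator would record this as a processing exception
          } finally {
            // Count down even on failure so the waiting thread never blocks on a dead task
            latch.countDown();
          }
          return null;
        }));
      }
      boolean allFinished = latch.await(timeoutMs, TimeUnit.MILLISECONDS);
      if (!allFinished) {
        for (Future<?> future : futures) {
          // cancel(true) interrupts the worker thread if the task already started;
          // it is a no-op for futures that are already done
          future.cancel(true);
        }
      }
      return allFinished;
    } finally {
      executor.shutdownNow();
    }
  }
}
```

A single `await(timeout)` on the latch bounds the total wait, whereas calling `Future.get(timeout)` on each future in turn could add up to a multiple of the timeout.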
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOrderByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOrderByOperator.java
index 5b724f1..c45f8ba 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOrderByOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOrderByOperator.java
@@ -41,6 +41,8 @@ import org.apache.pinot.core.data.table.ConcurrentIndexedTable;
 import org.apache.pinot.core.data.table.Key;
 import org.apache.pinot.core.data.table.Record;
 import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
 import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
 import org.apache.pinot.core.query.exception.EarlyTerminationException;
@@ -98,7 +100,8 @@ public class CombineGroupByOrderByOperator extends BaseOperator<IntermediateResu
    */
   @Override
   protected IntermediateResultsBlock getNextBlock() {
-    int numAggregationFunctions = _brokerRequest.getAggregationsInfoSize();
+    AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(_brokerRequest);
+    int numAggregationFunctions = aggregationFunctions.length;
     int numGroupBy = _brokerRequest.getGroupBy().getExpressionsSize();
     int numColumns = numGroupBy + numAggregationFunctions;
     ConcurrentLinkedQueue<ProcessingException> mergedProcessingExceptions = new ConcurrentLinkedQueue<>();
@@ -137,8 +140,9 @@ public class CombineGroupByOrderByOperator extends BaseOperator<IntermediateResu
             try {
               if (_dataSchema == null) {
                 _dataSchema = intermediateResultsBlock.getDataSchema();
-                _indexedTable = new ConcurrentIndexedTable(_dataSchema, _brokerRequest.getAggregationsInfo(),
-                    _brokerRequest.getOrderBy(), _indexedTableCapacity);
+                _indexedTable =
+                    new ConcurrentIndexedTable(_dataSchema, aggregationFunctions, _brokerRequest.getOrderBy(),
+                        _indexedTableCapacity);
               }
             } finally {
               _initLock.unlock();
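In CombineGroupByOrderByOperator, each worker creates `_dataSchema` and the shared `ConcurrentIndexedTable` only if they are still null, and it does so while holding `_initLock`. The `lock()` call is outside the lines shown. The sketch below shows that lock-guarded lazy initialization. It uses a plain `ConcurrentHashMap` as a stand-in for the indexed table, and every name in it is hypothetical.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

class LazySharedTableSketch {
  private final Lock _initLock = new ReentrantLock();
  // Both fields are only read or written while holding _initLock
  private ConcurrentHashMap<String, Long> _table;
  private int _numInitializations = 0;

  // The first caller creates the shared table; later callers reuse it
  private ConcurrentHashMap<String, Long> getOrCreateTable() {
    _initLock.lock();
    try {
      if (_table == null) {
        _table = new ConcurrentHashMap<>();
        _numInitializations++;
      }
      return _table;
    } finally {
      _initLock.unlock();
    }
  }

  // Called concurrently by workers; the map itself handles concurrent merges
  void merge(String key, long value) {
    getOrCreateTable().merge(key, value, Long::sum);
  }

  long get(String key) {
    return getOrCreateTable().getOrDefault(key, 0L);
  }

  int numInitializations() {
    _initLock.lock();
    try {
      return _numInitializations;
    } finally {
      _initLock.unlock();
    }
  }
}
```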
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/IntermediateResultsBlock.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/IntermediateResultsBlock.java
index 882c8b2..63efe7c 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/IntermediateResultsBlock.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/IntermediateResultsBlock.java
@@ -39,7 +39,7 @@ import org.apache.pinot.core.common.datatable.DataTableBuilder;
 import org.apache.pinot.core.common.datatable.DataTableImplV2;
 import org.apache.pinot.core.data.table.Record;
 import org.apache.pinot.core.data.table.Table;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
 import org.apache.pinot.core.query.selection.SelectionOperatorUtils;
 import org.apache.pinot.spi.utils.ByteArray;
@@ -51,7 +51,7 @@ import org.apache.pinot.spi.utils.ByteArray;
 public class IntermediateResultsBlock implements Block {
   private DataSchema _dataSchema;
   private Collection<Object[]> _selectionResult;
-  private AggregationFunctionContext[] _aggregationFunctionContexts;
+  private AggregationFunction[] _aggregationFunctions;
   private List<Object> _aggregationResult;
   private AggregationGroupByResult _aggregationGroupByResult;
   private List<Map<String, Object>> _combinedAggregationGroupByResult;
@@ -80,9 +80,9 @@ public class IntermediateResultsBlock implements Block {
    * <p>For aggregation group-by, the result is a list of maps from group keys to aggregation values.
    */
   @SuppressWarnings("unchecked")
-  public IntermediateResultsBlock(AggregationFunctionContext[] aggregationFunctionContexts, List aggregationResult,
+  public IntermediateResultsBlock(AggregationFunction[] aggregationFunctions, List aggregationResult,
       boolean isGroupBy) {
-    _aggregationFunctionContexts = aggregationFunctionContexts;
+    _aggregationFunctions = aggregationFunctions;
     if (isGroupBy) {
       _combinedAggregationGroupByResult = aggregationResult;
     } else {
@@ -93,18 +93,18 @@ public class IntermediateResultsBlock implements Block {
   /**
    * Constructor for aggregation group-by result with {@link AggregationGroupByResult}.
    */
-  public IntermediateResultsBlock(AggregationFunctionContext[] aggregationFunctionContexts,
+  public IntermediateResultsBlock(AggregationFunction[] aggregationFunctions,
       @Nullable AggregationGroupByResult aggregationGroupByResults) {
-    _aggregationFunctionContexts = aggregationFunctionContexts;
+    _aggregationFunctions = aggregationFunctions;
     _aggregationGroupByResult = aggregationGroupByResults;
   }
 
   /**
    * Constructor for aggregation group-by order-by result with {@link AggregationGroupByResult}.
    */
-  public IntermediateResultsBlock(AggregationFunctionContext[] aggregationFunctionContexts,
+  public IntermediateResultsBlock(AggregationFunction[] aggregationFunctions,
       @Nullable AggregationGroupByResult aggregationGroupByResults, DataSchema dataSchema) {
-    _aggregationFunctionContexts = aggregationFunctionContexts;
+    _aggregationFunctions = aggregationFunctions;
     _aggregationGroupByResult = aggregationGroupByResults;
     _dataSchema = dataSchema;
   }
@@ -134,7 +134,7 @@ public class IntermediateResultsBlock implements Block {
     return _dataSchema;
   }
 
-  public void setDataSchema(@Nullable DataSchema dataSchema) {
+  public void setDataSchema(DataSchema dataSchema) {
     _dataSchema = dataSchema;
   }
 
@@ -143,17 +143,17 @@ public class IntermediateResultsBlock implements Block {
     return _selectionResult;
   }
 
-  public void setSelectionResult(@Nullable Collection<Object[]> rowEventsSet) {
+  public void setSelectionResult(Collection<Object[]> rowEventsSet) {
     _selectionResult = rowEventsSet;
   }
 
   @Nullable
-  public AggregationFunctionContext[] getAggregationFunctionContexts() {
-    return _aggregationFunctionContexts;
+  public AggregationFunction[] getAggregationFunctions() {
+    return _aggregationFunctions;
   }
 
-  public void setAggregationFunctionContexts(AggregationFunctionContext[] aggregationFunctionContexts) {
-    _aggregationFunctionContexts = aggregationFunctionContexts;
+  public void setAggregationFunctions(AggregationFunction[] aggregationFunctions) {
+    _aggregationFunctions = aggregationFunctions;
   }
 
   @Nullable
@@ -161,7 +161,7 @@ public class IntermediateResultsBlock implements Block {
     return _aggregationResult;
   }
 
-  public void setAggregationResults(@Nullable List<Object> aggregationResults) {
+  public void setAggregationResults(List<Object> aggregationResults) {
     _aggregationResult = aggregationResults;
   }
 
@@ -175,7 +175,7 @@ public class IntermediateResultsBlock implements Block {
     return _processingExceptions;
   }
 
-  public void setProcessingExceptions(@Nullable List<ProcessingException> processingExceptions) {
+  public void setProcessingExceptions(List<ProcessingException> processingExceptions) {
     _processingExceptions = processingExceptions;
   }
 
@@ -322,14 +322,14 @@ public class IntermediateResultsBlock implements Block {
 
   private DataTable getAggregationResultDataTable()
       throws Exception {
-    // Extract each aggregation column name and type from aggregation function context.
-    int numAggregationFunctions = _aggregationFunctionContexts.length;
+    // Extract result column name and type from each aggregation function
+    int numAggregationFunctions = _aggregationFunctions.length;
     String[] columnNames = new String[numAggregationFunctions];
     ColumnDataType[] columnDataTypes = new ColumnDataType[numAggregationFunctions];
     for (int i = 0; i < numAggregationFunctions; i++) {
-      AggregationFunctionContext aggregationFunctionContext = _aggregationFunctionContexts[i];
-      columnNames[i] = aggregationFunctionContext.getAggregationColumnName();
-      columnDataTypes[i] = aggregationFunctionContext.getAggregationFunction().getIntermediateResultColumnType();
+      AggregationFunction aggregationFunction = _aggregationFunctions[i];
+      columnNames[i] = aggregationFunction.getColumnName();
+      columnDataTypes[i] = aggregationFunction.getIntermediateResultColumnType();
     }
 
     // Build the data table.
@@ -364,11 +364,11 @@ public class IntermediateResultsBlock implements Block {
 
     // Build the data table.
     DataTableBuilder dataTableBuilder = new DataTableBuilder(new DataSchema(columnNames, columnDataTypes));
-    int numAggregationFunctions = _aggregationFunctionContexts.length;
+    int numAggregationFunctions = _aggregationFunctions.length;
     for (int i = 0; i < numAggregationFunctions; i++) {
       dataTableBuilder.startRow();
-      AggregationFunctionContext aggregationFunctionContext = _aggregationFunctionContexts[i];
-      dataTableBuilder.setColumn(0, aggregationFunctionContext.getAggregationColumnName());
+      AggregationFunction aggregationFunction = _aggregationFunctions[i];
+      dataTableBuilder.setColumn(0, aggregationFunction.getColumnName());
       dataTableBuilder.setColumn(1, _combinedAggregationGroupByResult.get(i));
       dataTableBuilder.finishRow();
     }
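After this change, `getAggregationResultDataTable` reads each result column's name and intermediate type directly from the `AggregationFunction`, without going through a context wrapper. The sketch below shows that shape. It uses a hypothetical two-method interface and made-up `CountStar`/`Sum` implementations. The column-name format ("count_star", "sum_" + column) is illustrative only.

```java
class ResultSchemaSketch {

  // Hypothetical stand-in for DataSchema.ColumnDataType
  enum ColumnType { LONG, DOUBLE }

  // Hypothetical minimal function interface: each function describes its own result column
  interface AggFunction {
    String getColumnName();

    ColumnType getIntermediateResultColumnType();
  }

  static final class CountStar implements AggFunction {
    public String getColumnName() {
      return "count_star";
    }

    public ColumnType getIntermediateResultColumnType() {
      return ColumnType.LONG;
    }
  }

  static final class Sum implements AggFunction {
    private final String _column;

    Sum(String column) {
      _column = column;
    }

    public String getColumnName() {
      return "sum_" + _column;
    }

    public ColumnType getIntermediateResultColumnType() {
      return ColumnType.DOUBLE;
    }
  }

  // names[i] describes the intermediate result of functions[i]
  static String[] columnNames(AggFunction[] functions) {
    String[] names = new String[functions.length];
    for (int i = 0; i < functions.length; i++) {
      names[i] = functions[i].getColumnName();
    }
    return names;
  }

  // types[i] describes the intermediate result of functions[i]
  static ColumnType[] columnTypes(AggFunction[] functions) {
    ColumnType[] types = new ColumnType[functions.length];
    for (int i = 0; i < functions.length; i++) {
      types[i] = functions[i].getIntermediateResultColumnType();
    }
    return types;
  }
}
```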
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOperator.java
index 58ae7b0..9236530 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOperator.java
@@ -18,26 +18,27 @@
  */
 package org.apache.pinot.core.operator.query;
 
-import org.apache.pinot.common.request.GroupBy;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.core.operator.BaseOperator;
 import org.apache.pinot.core.operator.ExecutionStatistics;
 import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
 import org.apache.pinot.core.operator.blocks.TransformBlock;
 import org.apache.pinot.core.operator.transform.TransformOperator;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByExecutor;
 import org.apache.pinot.core.startree.executor.StarTreeGroupByExecutor;
 
 
 /**
- * The <code>AggregationOperator</code> class provides the operator for aggregation group-by query on a single segment.
+ * The <code>AggregationGroupByOperator</code> class provides the operator for aggregation group-by query on a single
+ * segment.
  */
 public class AggregationGroupByOperator extends BaseOperator<IntermediateResultsBlock> {
   private static final String OPERATOR_NAME = "AggregationGroupByOperator";
 
-  private final AggregationFunctionContext[] _functionContexts;
-  private final GroupBy _groupBy;
+  private final AggregationFunction[] _aggregationFunctions;
+  private final TransformExpressionTree[] _groupByExpressions;
   private final int _maxInitialResultHolderCapacity;
   private final int _numGroupsLimit;
   private final TransformOperator _transformOperator;
@@ -46,11 +47,11 @@ public class AggregationGroupByOperator extends BaseOperator<IntermediateResults
 
   private int _numDocsScanned = 0;
 
-  public AggregationGroupByOperator(AggregationFunctionContext[] functionContexts, GroupBy groupBy,
-      int maxInitialResultHolderCapacity, int numGroupsLimit, TransformOperator transformOperator, long numTotalDocs,
-      boolean useStarTree) {
-    _functionContexts = functionContexts;
-    _groupBy = groupBy;
+  public AggregationGroupByOperator(AggregationFunction[] aggregationFunctions,
+      TransformExpressionTree[] groupByExpressions, int maxInitialResultHolderCapacity, int numGroupsLimit,
+      TransformOperator transformOperator, long numTotalDocs, boolean useStarTree) {
+    _aggregationFunctions = aggregationFunctions;
+    _groupByExpressions = groupByExpressions;
     _maxInitialResultHolderCapacity = maxInitialResultHolderCapacity;
     _numGroupsLimit = numGroupsLimit;
     _transformOperator = transformOperator;
@@ -64,12 +65,12 @@ public class AggregationGroupByOperator extends BaseOperator<IntermediateResults
     GroupByExecutor groupByExecutor;
     if (_useStarTree) {
       groupByExecutor =
-          new StarTreeGroupByExecutor(_functionContexts, _groupBy, _maxInitialResultHolderCapacity, _numGroupsLimit,
-              _transformOperator);
+          new StarTreeGroupByExecutor(_aggregationFunctions, _groupByExpressions, _maxInitialResultHolderCapacity,
+              _numGroupsLimit, _transformOperator);
     } else {
       groupByExecutor =
-          new DefaultGroupByExecutor(_functionContexts, _groupBy, _maxInitialResultHolderCapacity, _numGroupsLimit,
-              _transformOperator);
+          new DefaultGroupByExecutor(_aggregationFunctions, _groupByExpressions, _maxInitialResultHolderCapacity,
+              _numGroupsLimit, _transformOperator);
     }
     TransformBlock transformBlock;
     while ((transformBlock = _transformOperator.nextBlock()) != null) {
@@ -78,7 +79,7 @@ public class AggregationGroupByOperator extends BaseOperator<IntermediateResults
     }
 
     // Build intermediate result block based on aggregation group-by result from the executor
-    return new IntermediateResultsBlock(_functionContexts, groupByExecutor.getResult());
+    return new IntermediateResultsBlock(_aggregationFunctions, groupByExecutor.getResult());
   }
 
   @Override
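AggregationGroupByOperator now takes pre-compiled `TransformExpressionTree[] groupByExpressions` instead of a `GroupBy` holding expression strings. The commit message also makes `TransformExpressionTree` the key of the `blockValSetMap`, citing cheaper `hashCode()` and `equals()` than string conversion. The sketch below shows why an immutable expression tree works well as a hash-map key: its hash is computed once and equality is structural. `ExprKey` is a hypothetical class and does not reproduce `TransformExpressionTree`'s actual implementation.

```java
import java.util.List;

// Hypothetical immutable expression node: a column such as "a" or a function such as add(a,b)
final class ExprKey {
  private final String _value;            // function name or column name
  private final List<ExprKey> _children;  // empty for columns
  private final int _hash;                // computed once, at construction

  private ExprKey(String value, List<ExprKey> children) {
    _value = value;
    _children = children;
    // List.hashCode() calls each child's hashCode(), which is already cached
    _hash = 31 * value.hashCode() + children.hashCode();
  }

  static ExprKey column(String name) {
    return new ExprKey(name, List.of());
  }

  static ExprKey function(String name, ExprKey... args) {
    return new ExprKey(name, List.of(args));
  }

  @Override
  public int hashCode() {
    return _hash;
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof ExprKey)) {
      return false;
    }
    ExprKey other = (ExprKey) o;
    // Cheap hash check first, then structural comparison
    return _hash == other._hash && _value.equals(other._value) && _children.equals(other._children);
  }

  @Override
  public String toString() {
    if (_children.isEmpty()) {
      return _value;
    }
    StringBuilder sb = new StringBuilder(_value).append('(');
    for (int i = 0; i < _children.size(); i++) {
      if (i > 0) {
        sb.append(',');
      }
      sb.append(_children.get(i));
    }
    return sb.append(')').toString();
  }
}
```

Because the tree is compiled once at planning time, lookups during execution never rebuild or re-parse a string. The string form is produced only when a column name is needed, as with `groupByExpression.toString()` in the schema code.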
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOrderByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOrderByOperator.java
index a2f840d..abd26c3 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOrderByOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOrderByOperator.java
@@ -18,7 +18,6 @@
  */
 package org.apache.pinot.core.operator.query;
 
-import org.apache.pinot.common.request.GroupBy;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.operator.BaseOperator;
@@ -26,35 +25,36 @@ import org.apache.pinot.core.operator.ExecutionStatistics;
 import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
 import org.apache.pinot.core.operator.blocks.TransformBlock;
 import org.apache.pinot.core.operator.transform.TransformOperator;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByExecutor;
 import org.apache.pinot.core.startree.executor.StarTreeGroupByExecutor;
 
 
 /**
- * The <code>AggregationOperator</code> class provides the operator for aggregation group-by query on a single segment.
+ * The <code>AggregationGroupByOrderByOperator</code> class provides the operator for aggregation group-by query on a
+ * single segment.
  */
 public class AggregationGroupByOrderByOperator extends BaseOperator<IntermediateResultsBlock> {
   private static final String OPERATOR_NAME = "AggregationGroupByOrderByOperator";
 
   private final DataSchema _dataSchema;
 
-  private final AggregationFunctionContext[] _functionContexts;
-  private final GroupBy _groupBy;
+  private final AggregationFunction[] _aggregationFunctions;
+  private final TransformExpressionTree[] _groupByExpressions;
   private final int _maxInitialResultHolderCapacity;
   private final int _numGroupsLimit;
   private final TransformOperator _transformOperator;
   private final long _numTotalDocs;
   private final boolean _useStarTree;
 
-  private int _numDocsScanned;
+  private int _numDocsScanned = 0;
 
-  public AggregationGroupByOrderByOperator(AggregationFunctionContext[] functionContexts, GroupBy groupBy,
-      int maxInitialResultHolderCapacity, int numGroupsLimit, TransformOperator transformOperator, long numTotalDocs,
-      boolean useStarTree) {
-    _functionContexts = functionContexts;
-    _groupBy = groupBy;
+  public AggregationGroupByOrderByOperator(AggregationFunction[] aggregationFunctions,
+      TransformExpressionTree[] groupByExpressions, int maxInitialResultHolderCapacity, int numGroupsLimit,
+      TransformOperator transformOperator, long numTotalDocs, boolean useStarTree) {
+    _aggregationFunctions = aggregationFunctions;
+    _groupByExpressions = groupByExpressions;
     _maxInitialResultHolderCapacity = maxInitialResultHolderCapacity;
     _numGroupsLimit = numGroupsLimit;
     _transformOperator = transformOperator;
@@ -62,30 +62,28 @@ public class AggregationGroupByOrderByOperator extends BaseOperator<Intermediate
     _useStarTree = useStarTree;
 
     // NOTE: The indexedTable expects that the data schema will have group by columns before aggregation columns
-    int numColumns = groupBy.getExpressionsSize() + _functionContexts.length;
+    int numGroupByExpressions = groupByExpressions.length;
+    int numAggregationFunctions = aggregationFunctions.length;
+    int numColumns = numGroupByExpressions + numAggregationFunctions;
     String[] columnNames = new String[numColumns];
     DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[numColumns];
 
-    // extract column names and data types for group by keys
-    int index = 0;
-    for (String groupByColumn : groupBy.getExpressions()) {
-      columnNames[index] = groupByColumn;
-      TransformExpressionTree expression = TransformExpressionTree.compileToExpressionTree(groupByColumn);
-      columnDataTypes[index] =
-          DataSchema.ColumnDataType.fromDataTypeSV(_transformOperator.getResultMetadata(expression).getDataType());
-      index++;
+    // Extract column names and data types for group-by columns
+    for (int i = 0; i < numGroupByExpressions; i++) {
+      TransformExpressionTree groupByExpression = groupByExpressions[i];
+      columnNames[i] = groupByExpression.toString();
+      columnDataTypes[i] = DataSchema.ColumnDataType
+          .fromDataTypeSV(_transformOperator.getResultMetadata(groupByExpression).getDataType());
     }
 
-    // extract column names and data types for aggregations
-    for (AggregationFunctionContext functionContext : functionContexts) {
-      columnNames[index] = functionContext.getResultColumnName();
-      columnDataTypes[index] = functionContext.getAggregationFunction().getIntermediateResultColumnType();
-      index++;
+    // Extract column names and data types for aggregation functions
+    for (int i = 0; i < numAggregationFunctions; i++) {
+      AggregationFunction aggregationFunction = aggregationFunctions[i];
+      int index = numGroupByExpressions + i;
+      columnNames[index] = aggregationFunction.getResultColumnName();
+      columnDataTypes[index] = aggregationFunction.getIntermediateResultColumnType();
     }
 
-    // TODO: We need to support putting order by columns in the data schema.
-    //  It is possible that the order by column is not one of the group by or aggregation columns
-
     _dataSchema = new DataSchema(columnNames, columnDataTypes);
   }
 
@@ -95,12 +93,12 @@ public class AggregationGroupByOrderByOperator extends BaseOperator<Intermediate
     GroupByExecutor groupByExecutor;
     if (_useStarTree) {
       groupByExecutor =
-          new StarTreeGroupByExecutor(_functionContexts, _groupBy, _maxInitialResultHolderCapacity, _numGroupsLimit,
-              _transformOperator);
+          new StarTreeGroupByExecutor(_aggregationFunctions, _groupByExpressions, _maxInitialResultHolderCapacity,
+              _numGroupsLimit, _transformOperator);
     } else {
       groupByExecutor =
-          new DefaultGroupByExecutor(_functionContexts, _groupBy, _maxInitialResultHolderCapacity, _numGroupsLimit,
-              _transformOperator);
+          new DefaultGroupByExecutor(_aggregationFunctions, _groupByExpressions, _maxInitialResultHolderCapacity,
+              _numGroupsLimit, _transformOperator);
     }
     TransformBlock transformBlock;
     while ((transformBlock = _transformOperator.nextBlock()) != null) {
@@ -109,7 +107,7 @@ public class AggregationGroupByOrderByOperator extends BaseOperator<Intermediate
     }
 
     // Build intermediate result block based on aggregation group-by result from the executor
-    return new IntermediateResultsBlock(_functionContexts, groupByExecutor.getResult(), _dataSchema);
+    return new IntermediateResultsBlock(_aggregationFunctions, groupByExecutor.getResult(), _dataSchema);
   }
 
   @Override
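AggregationGroupByOrderByOperator lays out its DataSchema with every group-by column before the aggregation columns, placing aggregation `i` at index `numGroupByExpressions + i`. TableResizer depends on that ordering when it computes `numKeyColumns = numColumns - numAggregations`. The sketch below shows both sides with plain string arrays. The class and method names are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

class GroupBySchemaSketch {

  // Producer side: all group-by columns first, then one column per aggregation function
  static String[] buildColumnNames(String[] groupByColumns, String[] aggregationColumns) {
    int numGroupBy = groupByColumns.length;
    String[] columnNames = new String[numGroupBy + aggregationColumns.length];
    System.arraycopy(groupByColumns, 0, columnNames, 0, numGroupBy);
    System.arraycopy(aggregationColumns, 0, columnNames, numGroupBy, aggregationColumns.length);
    return columnNames;
  }

  // Consumer side: identifies aggregation columns purely from the counts,
  // which is only correct because the producer puts key columns first
  static Map<String, Integer> aggregationColumnToFunctionIndex(String[] columnNames, int numAggregations) {
    int numKeyColumns = columnNames.length - numAggregations;
    Map<String, Integer> result = new HashMap<>();
    for (int i = numKeyColumns; i < columnNames.length; i++) {
      result.put(columnNames[i], i - numKeyColumns);
    }
    return result;
  }
}
```

If either side changed the order, the consumer would silently pair key columns with aggregation functions. This is why both hunks carry a NOTE about the assumption.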
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java
index a20d39d..e29976a 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java
@@ -24,8 +24,8 @@ import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
 import org.apache.pinot.core.operator.blocks.TransformBlock;
 import org.apache.pinot.core.operator.transform.TransformOperator;
 import org.apache.pinot.core.query.aggregation.AggregationExecutor;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
 import org.apache.pinot.core.query.aggregation.DefaultAggregationExecutor;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.startree.executor.StarTreeAggregationExecutor;
 
 
@@ -35,16 +35,16 @@ import org.apache.pinot.core.startree.executor.StarTreeAggregationExecutor;
 public class AggregationOperator extends BaseOperator<IntermediateResultsBlock> {
   private static final String OPERATOR_NAME = "AggregationOperator";
 
-  private final AggregationFunctionContext[] _functionContexts;
+  private final AggregationFunction[] _aggregationFunctions;
   private final TransformOperator _transformOperator;
   private final long _numTotalDocs;
   private final boolean _useStarTree;
 
   private int _numDocsScanned = 0;
 
-  public AggregationOperator(AggregationFunctionContext[] functionContexts, TransformOperator transformOperator,
+  public AggregationOperator(AggregationFunction[] aggregationFunctions, TransformOperator transformOperator,
       long numTotalDocs, boolean useStarTree) {
-    _functionContexts = functionContexts;
+    _aggregationFunctions = aggregationFunctions;
     _transformOperator = transformOperator;
     _numTotalDocs = numTotalDocs;
     _useStarTree = useStarTree;
@@ -55,9 +55,9 @@ public class AggregationOperator extends BaseOperator<IntermediateResultsBlock>
     // Perform aggregation on all the transform blocks
     AggregationExecutor aggregationExecutor;
     if (_useStarTree) {
-      aggregationExecutor = new StarTreeAggregationExecutor(_functionContexts);
+      aggregationExecutor = new StarTreeAggregationExecutor(_aggregationFunctions);
     } else {
-      aggregationExecutor = new DefaultAggregationExecutor(_functionContexts);
+      aggregationExecutor = new DefaultAggregationExecutor(_aggregationFunctions);
     }
     TransformBlock transformBlock;
     while ((transformBlock = _transformOperator.nextBlock()) != null) {
@@ -66,7 +66,7 @@ public class AggregationOperator extends BaseOperator<IntermediateResultsBlock>
     }
 
     // Build intermediate result block based on aggregation result from the executor
-    return new IntermediateResultsBlock(_functionContexts, aggregationExecutor.getResult(), false);
+    return new IntermediateResultsBlock(_aggregationFunctions, aggregationExecutor.getResult(), false);
   }
 
   @Override
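AggregationOperator picks one executor up front: `StarTreeAggregationExecutor` when `_useStarTree` is set, otherwise `DefaultAggregationExecutor`. It then pulls transform blocks until `nextBlock()` returns null and builds the result from the executor. The sketch below shows that pull loop for a single count aggregation. It assumes that star-tree blocks carry pre-aggregated counts rather than raw rows. This is a simplification for illustration, and all the types in it are hypothetical.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;

class PullLoopSketch {

  // Hypothetical block source: returns null once exhausted, like TransformOperator.nextBlock()
  interface BlockSource {
    long[] nextBlock();
  }

  interface CountExecutor {
    void aggregate(long[] block);

    long getResult();
  }

  // Raw rows: every value in a block stands for one row
  static final class RowCountExecutor implements CountExecutor {
    private long _count;

    public void aggregate(long[] block) {
      _count += block.length;
    }

    public long getResult() {
      return _count;
    }
  }

  // Pre-aggregated rows (assumed star-tree style): each value is already a count of underlying rows
  static final class PreAggregatedCountExecutor implements CountExecutor {
    private long _count;

    public void aggregate(long[] block) {
      for (long c : block) {
        _count += c;
      }
    }

    public long getResult() {
      return _count;
    }
  }

  static long count(BlockSource source, boolean preAggregated) {
    // Choose the executor once, then drive it block by block
    CountExecutor executor = preAggregated ? new PreAggregatedCountExecutor() : new RowCountExecutor();
    long[] block;
    while ((block = source.nextBlock()) != null) {
      executor.aggregate(block);
    }
    return executor.getResult();
  }

  // Test helper: serves the given blocks in order, then null
  static BlockSource of(long[]... blocks) {
    Deque<long[]> queue = new ArrayDeque<>(Arrays.asList(blocks));
    return queue::poll;
  }
}
```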
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java
index 5c230da..20cd2e8 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java
@@ -21,14 +21,10 @@ package org.apache.pinot.core.operator.query;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
-import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.core.operator.BaseOperator;
 import org.apache.pinot.core.operator.ExecutionStatistics;
 import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
-import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
-import org.apache.pinot.core.query.aggregation.DoubleAggregationResultHolder;
-import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.customobject.MinMaxRangePair;
 import org.apache.pinot.core.segment.index.readers.Dictionary;
@@ -47,56 +43,43 @@ import org.apache.pinot.core.segment.index.readers.Dictionary;
 public class DictionaryBasedAggregationOperator extends BaseOperator<IntermediateResultsBlock> {
   private static final String OPERATOR_NAME = "DictionaryBasedAggregationOperator";
 
-  private final AggregationFunctionContext[] _aggregationFunctionContexts;
+  private final AggregationFunction[] _aggregationFunctions;
   private final Map<String, Dictionary> _dictionaryMap;
-  private final long _numTotalDocs;
+  private final int _numTotalDocs;
 
-  /**
-   * Constructor for the class.
-   * @param aggregationFunctionContexts Aggregation function contexts.
-   * @param numTotalDocs total raw docs from segmet metadata
-   * @param dictionaryMap Map of column to its dictionary.
-   */
-  public DictionaryBasedAggregationOperator(AggregationFunctionContext[] aggregationFunctionContexts, long numTotalDocs,
-      Map<String, Dictionary> dictionaryMap) {
-    _aggregationFunctionContexts = aggregationFunctionContexts;
+  public DictionaryBasedAggregationOperator(AggregationFunction[] aggregationFunctions,
+      Map<String, Dictionary> dictionaryMap, int numTotalDocs) {
+    _aggregationFunctions = aggregationFunctions;
     _dictionaryMap = dictionaryMap;
     _numTotalDocs = numTotalDocs;
   }
 
   @Override
   protected IntermediateResultsBlock getNextBlock() {
-    int numAggregationFunctions = _aggregationFunctionContexts.length;
+    int numAggregationFunctions = _aggregationFunctions.length;
     List<Object> aggregationResults = new ArrayList<>(numAggregationFunctions);
-
-    for (AggregationFunctionContext aggregationFunctionContext : _aggregationFunctionContexts) {
-      AggregationFunction function = aggregationFunctionContext.getAggregationFunction();
-      AggregationFunctionType functionType = function.getType();
-      String column = aggregationFunctionContext.getColumnName();
+    for (AggregationFunction aggregationFunction : _aggregationFunctions) {
+      String column = ((TransformExpressionTree) aggregationFunction.getInputExpressions().get(0)).getValue();
       Dictionary dictionary = _dictionaryMap.get(column);
-      AggregationResultHolder resultHolder;
-      switch (functionType) {
+      switch (aggregationFunction.getType()) {
         case MAX:
-          resultHolder = new DoubleAggregationResultHolder(dictionary.getDoubleValue(dictionary.length() - 1));
+          aggregationResults.add(dictionary.getDoubleValue(dictionary.length() - 1));
           break;
         case MIN:
-          resultHolder = new DoubleAggregationResultHolder(dictionary.getDoubleValue(0));
+          aggregationResults.add(dictionary.getDoubleValue(0));
           break;
         case MINMAXRANGE:
-          double max = dictionary.getDoubleValue(dictionary.length() - 1);
-          double min = dictionary.getDoubleValue(0);
-          resultHolder = new ObjectAggregationResultHolder();
-          resultHolder.setValue(new MinMaxRangePair(min, max));
+          aggregationResults.add(
+              new MinMaxRangePair(dictionary.getDoubleValue(0), dictionary.getDoubleValue(dictionary.length() - 1)));
           break;
         default:
           throw new IllegalStateException(
-              "Dictionary based aggregation operator does not support function type: " + functionType);
+              "Dictionary based aggregation operator does not support function type: " + aggregationFunction.getType());
       }
-      aggregationResults.add(function.extractAggregationResult(resultHolder));
     }
 
     // Build intermediate result block based on aggregation result from the executor.
-    return new IntermediateResultsBlock(_aggregationFunctionContexts, aggregationResults, false);
+    return new IntermediateResultsBlock(_aggregationFunctions, aggregationResults, false);
   }
 
   @Override
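DictionaryBasedAggregationOperator now appends results directly instead of going through an AggregationResultHolder. The operator reads MIN from dictionary id 0 and MAX from id length() - 1, which relies on the dictionary holding the column's distinct values in sorted order. Below is a minimal sketch of that idea for a single numeric column. It makes three assumptions. First, SortedDoubleDictionary is hypothetical and backed by a sorted double[]; it is not Pinot's Dictionary interface. Second, a plain {min, max} array stands in for MinMaxRangePair. Third, the dictionary is non-empty.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sorted dictionary: the distinct values of one column in ascending order.
final class SortedDoubleDictionary {
  private final double[] _sortedValues;

  SortedDoubleDictionary(double... values) {
    _sortedValues = Arrays.stream(values).distinct().sorted().toArray();
  }

  int length() {
    return _sortedValues.length;
  }

  double getDoubleValue(int dictId) {
    return _sortedValues[dictId];
  }
}

final class DictionaryAggregationSketch {
  // Hypothetical function types for this sketch only
  enum FunctionType { MIN, MAX, MINMAXRANGE, SUM }

  // Answers each function from the dictionary alone, without scanning any documents.
  // Assumes the dictionary has at least one value.
  static List<Object> aggregate(FunctionType[] types, SortedDoubleDictionary dictionary) {
    List<Object> results = new ArrayList<>(types.length);
    for (FunctionType type : types) {
      double min = dictionary.getDoubleValue(0);
      double max = dictionary.getDoubleValue(dictionary.length() - 1);
      switch (type) {
        case MIN:
          results.add(min);
          break;
        case MAX:
          results.add(max);
          break;
        case MINMAXRANGE:
          // Keep both ends so the range can be derived when the result is finalized
          results.add(new double[]{min, max});
          break;
        default:
          throw new IllegalStateException("Dictionary based aggregation does not support function type: " + type);
      }
    }
    return results;
  }
}
```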
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/MetadataBasedAggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/MetadataBasedAggregationOperator.java
index f999e62..514299b 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/MetadataBasedAggregationOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/MetadataBasedAggregationOperator.java
@@ -27,7 +27,7 @@ import org.apache.pinot.core.common.DataSource;
 import org.apache.pinot.core.operator.BaseOperator;
 import org.apache.pinot.core.operator.ExecutionStatistics;
 import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.segment.index.metadata.SegmentMetadata;
 
 
@@ -37,41 +37,33 @@ import org.apache.pinot.core.segment.index.metadata.SegmentMetadata;
 public class MetadataBasedAggregationOperator extends BaseOperator<IntermediateResultsBlock> {
   private static final String OPERATOR_NAME = "MetadataBasedAggregationOperator";
 
-  private final AggregationFunctionContext[] _aggregationFunctionContexts;
-  private final Map<String, DataSource> _dataSourceMap;
+  private final AggregationFunction[] _aggregationFunctions;
   private final SegmentMetadata _segmentMetadata;
+  private final Map<String, DataSource> _dataSourceMap;
 
-  /**
-   * Constructor for the class.
-   *
-   * @param aggregationFunctionContexts Aggregation function contexts.
-   * @param segmentMetadata Segment metadata.
-   * @param dataSourceMap Map of column to its data source.
-   */
-  public MetadataBasedAggregationOperator(AggregationFunctionContext[] aggregationFunctionContexts,
-      SegmentMetadata segmentMetadata, Map<String, DataSource> dataSourceMap) {
-    _aggregationFunctionContexts = aggregationFunctionContexts;
+  public MetadataBasedAggregationOperator(AggregationFunction[] aggregationFunctions, SegmentMetadata segmentMetadata,
+      Map<String, DataSource> dataSourceMap) {
+    _aggregationFunctions = aggregationFunctions;
+    _segmentMetadata = segmentMetadata;
 
     // Datasource is currently not used, but will start getting used as we add support for aggregation
     // functions other than count(*).
     _dataSourceMap = dataSourceMap;
-    _segmentMetadata = segmentMetadata;
   }
 
   @Override
   protected IntermediateResultsBlock getNextBlock() {
-    int numAggregationFunctions = _aggregationFunctionContexts.length;
+    int numAggregationFunctions = _aggregationFunctions.length;
     List<Object> aggregationResults = new ArrayList<>(numAggregationFunctions);
     long numTotalDocs = _segmentMetadata.getTotalDocs();
-    for (AggregationFunctionContext aggregationFunctionContext : _aggregationFunctionContexts) {
-      AggregationFunctionType functionType = aggregationFunctionContext.getAggregationFunction().getType();
-      Preconditions.checkState(functionType == AggregationFunctionType.COUNT,
-          "Metadata based aggregation operator does not support function type: " + functionType);
+    for (AggregationFunction aggregationFunction : _aggregationFunctions) {
+      Preconditions.checkState(aggregationFunction.getType() == AggregationFunctionType.COUNT,
+          "Metadata based aggregation operator does not support function type: " + aggregationFunction.getType());
       aggregationResults.add(numTotalDocs);
     }
 
     // Build intermediate result block based on aggregation result from the executor.
-    return new IntermediateResultsBlock(_aggregationFunctionContexts, aggregationResults, false);
+    return new IntermediateResultsBlock(_aggregationFunctions, aggregationResults, false);
   }
 
   @Override
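MetadataBasedAggregationOperator answers COUNT from segment metadata. Every function must be COUNT; otherwise Preconditions.checkState throws an IllegalStateException. Each result is the segment's total document count. Below is a standalone sketch of that loop. A hypothetical FunctionType enum replaces AggregationFunctionType, and a plain if/throw replaces Guava's Preconditions.

```java
import java.util.ArrayList;
import java.util.List;

final class MetadataCountSketch {
  // Hypothetical function types for this sketch only
  enum FunctionType { COUNT, SUM }

  // Every function must be COUNT; each result is the total document count read from
  // segment metadata, so no documents are scanned.
  static List<Object> aggregate(FunctionType[] types, long numTotalDocs) {
    List<Object> results = new ArrayList<>(types.length);
    for (FunctionType type : types) {
      if (type != FunctionType.COUNT) {
        throw new IllegalStateException("Metadata based aggregation does not support function type: " + type);
      }
      results.add(numTotalDocs);
    }
    return results;
  }
}
```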
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByOrderByPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByOrderByPlanNode.java
index 70fb1e2..a22de73 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByOrderByPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByOrderByPlanNode.java
@@ -18,7 +18,6 @@
  */
 package org.apache.pinot.core.plan;
 
-import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 import org.apache.pinot.common.request.AggregationInfo;
@@ -29,7 +28,7 @@ import org.apache.pinot.common.utils.request.FilterQueryTree;
 import org.apache.pinot.common.utils.request.RequestUtils;
 import org.apache.pinot.core.indexsegment.IndexSegment;
 import org.apache.pinot.core.operator.query.AggregationGroupByOrderByOperator;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 import org.apache.pinot.core.startree.StarTreeUtils;
 import org.apache.pinot.core.startree.plan.StarTreeTransformPlanNode;
@@ -50,8 +49,9 @@ public class AggregationGroupByOrderByPlanNode implements PlanNode {
   private final int _maxInitialResultHolderCapacity;
   private final int _numGroupsLimit;
   private final List<AggregationInfo> _aggregationInfos;
-  private final AggregationFunctionContext[] _functionContexts;
+  private final AggregationFunction[] _aggregationFunctions;
   private final GroupBy _groupBy;
+  private final TransformExpressionTree[] _groupByExpressions;
   private final TransformPlanNode _transformPlanNode;
   private final StarTreeTransformPlanNode _starTreeTransformPlanNode;
 
@@ -61,38 +61,51 @@ public class AggregationGroupByOrderByPlanNode implements PlanNode {
     _maxInitialResultHolderCapacity = maxInitialResultHolderCapacity;
     _numGroupsLimit = numGroupsLimit;
     _aggregationInfos = brokerRequest.getAggregationsInfo();
-    _functionContexts = AggregationFunctionUtils.getAggregationFunctionContexts(brokerRequest);
+    _aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(brokerRequest);
     _groupBy = brokerRequest.getGroupBy();
-
-    Set<TransformExpressionTree> expressionsToTransform =
-        AggregationFunctionUtils.collectExpressionsToTransform(brokerRequest, _functionContexts);
+    List<String> groupByExpressions = _groupBy.getExpressions();
+    int numGroupByExpressions = groupByExpressions.size();
+    _groupByExpressions = new TransformExpressionTree[numGroupByExpressions];
+    for (int i = 0; i < numGroupByExpressions; i++) {
+      _groupByExpressions[i] = TransformExpressionTree.compileToExpressionTree(groupByExpressions.get(i));
+    }
 
     List<StarTreeV2> starTrees = indexSegment.getStarTrees();
     if (starTrees != null) {
       if (!StarTreeUtils.isStarTreeDisabled(brokerRequest)) {
-        Set<AggregationFunctionColumnPair> aggregationFunctionColumnPairs = new HashSet<>();
-        for (AggregationInfo aggregationInfo : _aggregationInfos) {
-          aggregationFunctionColumnPairs.add(AggregationFunctionUtils.getFunctionColumnPair(aggregationInfo));
-        }
-        Set<TransformExpressionTree> groupByExpressions = new HashSet<>();
-        for (String expression : _groupBy.getExpressions()) {
-          groupByExpressions.add(TransformExpressionTree.compileToExpressionTree(expression));
+        int numAggregationFunctions = _aggregationFunctions.length;
+        AggregationFunctionColumnPair[] aggregationFunctionColumnPairs =
+            new AggregationFunctionColumnPair[numAggregationFunctions];
+        boolean hasUnsupportedAggregationFunction = false;
+        for (int i = 0; i < numAggregationFunctions; i++) {
+          AggregationFunctionColumnPair aggregationFunctionColumnPair =
+              AggregationFunctionUtils.getAggregationFunctionColumnPair(_aggregationFunctions[i]);
+          if (aggregationFunctionColumnPair != null) {
+            aggregationFunctionColumnPairs[i] = aggregationFunctionColumnPair;
+          } else {
+            hasUnsupportedAggregationFunction = true;
+            break;
+          }
         }
-        FilterQueryTree rootFilterNode = RequestUtils.generateFilterQueryTree(brokerRequest);
-        for (StarTreeV2 starTreeV2 : starTrees) {
-          if (StarTreeUtils
-              .isFitForStarTree(starTreeV2.getMetadata(), aggregationFunctionColumnPairs, groupByExpressions,
-                  rootFilterNode)) {
-            _transformPlanNode = null;
-            _starTreeTransformPlanNode =
-                new StarTreeTransformPlanNode(starTreeV2, aggregationFunctionColumnPairs, groupByExpressions,
-                    rootFilterNode, brokerRequest.getDebugOptions());
-            return;
+        if (!hasUnsupportedAggregationFunction) {
+          FilterQueryTree rootFilterNode = RequestUtils.generateFilterQueryTree(brokerRequest);
+          for (StarTreeV2 starTreeV2 : starTrees) {
+            if (StarTreeUtils
+                .isFitForStarTree(starTreeV2.getMetadata(), aggregationFunctionColumnPairs, _groupByExpressions,
+                    rootFilterNode)) {
+              _transformPlanNode = null;
+              _starTreeTransformPlanNode =
+                  new StarTreeTransformPlanNode(starTreeV2, aggregationFunctionColumnPairs, _groupByExpressions,
+                      rootFilterNode, brokerRequest.getDebugOptions());
+              return;
+            }
           }
         }
       }
     }
 
+    Set<TransformExpressionTree> expressionsToTransform =
+        AggregationFunctionUtils.collectExpressionsToTransform(_aggregationFunctions, _groupByExpressions);
     _transformPlanNode = new TransformPlanNode(_indexSegment, brokerRequest, expressionsToTransform);
     _starTreeTransformPlanNode = null;
   }
@@ -102,19 +115,19 @@ public class AggregationGroupByOrderByPlanNode implements PlanNode {
     int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
     if (_transformPlanNode != null) {
       // Do not use star-tree
-      return new AggregationGroupByOrderByOperator(_functionContexts, _groupBy, _maxInitialResultHolderCapacity,
-          _numGroupsLimit, _transformPlanNode.run(), numTotalDocs, false);
+      return new AggregationGroupByOrderByOperator(_aggregationFunctions, _groupByExpressions,
+          _maxInitialResultHolderCapacity, _numGroupsLimit, _transformPlanNode.run(), numTotalDocs, false);
     } else {
       // Use star-tree
-      return new AggregationGroupByOrderByOperator(_functionContexts, _groupBy, _maxInitialResultHolderCapacity,
-          _numGroupsLimit, _starTreeTransformPlanNode.run(), numTotalDocs, true);
+      return new AggregationGroupByOrderByOperator(_aggregationFunctions, _groupByExpressions,
+          _maxInitialResultHolderCapacity, _numGroupsLimit, _starTreeTransformPlanNode.run(), numTotalDocs, true);
     }
   }
 
   @Override
   public void showTree(String prefix) {
-    LOGGER.debug(prefix + "Aggregation Group-by Plan Node:");
-    LOGGER.debug(prefix + "Operator: AggregationGroupByOperator");
+    LOGGER.debug(prefix + "Aggregation Group-by Order-by Plan Node:");
+    LOGGER.debug(prefix + "Operator: AggregationGroupByOrderByOperator");
     LOGGER.debug(prefix + "Argument 0: IndexSegment - " + _indexSegment.getSegmentName());
     LOGGER.debug(prefix + "Argument 1: Aggregations - " + _aggregationInfos);
     LOGGER.debug(prefix + "Argument 2: GroupBy - " + _groupBy);
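The plan nodes now convert every AggregationFunction into an AggregationFunctionColumnPair up front. If any conversion returns null, the star-tree path is skipped entirely and the regular TransformPlanNode is used. The sketch below models that all-or-nothing selection, and all of its names are hypothetical:

- Each function is a {functionName, column} array.
- Pairs are "FUNCTION__column" strings.
- The set of convertible functions is made up for illustration.
- The fit check compares only pre-aggregated pairs and dimensions. StarTreeUtils.isFitForStarTree also takes the filter tree into account.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Set;

final class StarTreeSelectionSketch {
  // Hypothetical conversion: returns a "FUNCTION__column" pair, or null when the function
  // cannot be served by a star-tree (the supported names here are made up).
  static String toFunctionColumnPair(String functionName, String column) {
    switch (functionName) {
      case "COUNT":
        return "COUNT__*";
      case "SUM":
      case "MIN":
      case "MAX":
        return functionName + "__" + column;
      default:
        return null;
    }
  }

  // Hypothetical star-tree metadata: pre-aggregated pairs plus dimension columns.
  static final class StarTreeMeta {
    final Set<String> functionColumnPairs;
    final Set<String> dimensions;

    StarTreeMeta(Set<String> functionColumnPairs, Set<String> dimensions) {
      this.functionColumnPairs = functionColumnPairs;
      this.dimensions = dimensions;
    }
  }

  // Returns the index of the first star-tree that fits, or -1 to fall back to the
  // regular (non-star-tree) transform path. Filters are ignored in this sketch.
  static int pickStarTree(String[][] functions, String[] groupByColumns, List<StarTreeMeta> starTrees) {
    String[] pairs = new String[functions.length];
    for (int i = 0; i < functions.length; i++) {
      pairs[i] = toFunctionColumnPair(functions[i][0], functions[i][1]);
      if (pairs[i] == null) {
        // One unsupported function rules out every star-tree, so stop converting early
        return -1;
      }
    }
    List<String> requiredPairs = Arrays.asList(pairs);
    List<String> requiredDimensions = Arrays.asList(groupByColumns);
    for (int i = 0; i < starTrees.size(); i++) {
      StarTreeMeta meta = starTrees.get(i);
      if (meta.functionColumnPairs.containsAll(requiredPairs) && meta.dimensions.containsAll(requiredDimensions)) {
        return i;
      }
    }
    return -1;
  }
}
```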
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByPlanNode.java
index 44b456c..543903e 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByPlanNode.java
@@ -18,7 +18,6 @@
  */
 package org.apache.pinot.core.plan;
 
-import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 import org.apache.pinot.common.request.AggregationInfo;
@@ -29,7 +28,7 @@ import org.apache.pinot.common.utils.request.FilterQueryTree;
 import org.apache.pinot.common.utils.request.RequestUtils;
 import org.apache.pinot.core.indexsegment.IndexSegment;
 import org.apache.pinot.core.operator.query.AggregationGroupByOperator;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 import org.apache.pinot.core.startree.StarTreeUtils;
 import org.apache.pinot.core.startree.plan.StarTreeTransformPlanNode;
@@ -50,8 +49,9 @@ public class AggregationGroupByPlanNode implements PlanNode {
   private final int _maxInitialResultHolderCapacity;
   private final int _numGroupsLimit;
   private final List<AggregationInfo> _aggregationInfos;
-  private final AggregationFunctionContext[] _functionContexts;
+  private final AggregationFunction[] _aggregationFunctions;
   private final GroupBy _groupBy;
+  private final TransformExpressionTree[] _groupByExpressions;
   private final TransformPlanNode _transformPlanNode;
   private final StarTreeTransformPlanNode _starTreeTransformPlanNode;
 
@@ -61,37 +61,51 @@ public class AggregationGroupByPlanNode implements PlanNode {
     _maxInitialResultHolderCapacity = maxInitialResultHolderCapacity;
     _numGroupsLimit = numGroupsLimit;
     _aggregationInfos = brokerRequest.getAggregationsInfo();
-    _functionContexts = AggregationFunctionUtils.getAggregationFunctionContexts(brokerRequest);
+    _aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(brokerRequest);
     _groupBy = brokerRequest.getGroupBy();
+    List<String> groupByExpressions = _groupBy.getExpressions();
+    int numGroupByExpressions = groupByExpressions.size();
+    _groupByExpressions = new TransformExpressionTree[numGroupByExpressions];
+    for (int i = 0; i < numGroupByExpressions; i++) {
+      _groupByExpressions[i] = TransformExpressionTree.compileToExpressionTree(groupByExpressions.get(i));
+    }
 
     List<StarTreeV2> starTrees = indexSegment.getStarTrees();
     if (starTrees != null) {
       if (!StarTreeUtils.isStarTreeDisabled(brokerRequest)) {
-        Set<AggregationFunctionColumnPair> aggregationFunctionColumnPairs = new HashSet<>();
-        for (AggregationInfo aggregationInfo : _aggregationInfos) {
-          aggregationFunctionColumnPairs.add(AggregationFunctionUtils.getFunctionColumnPair(aggregationInfo));
-        }
-        Set<TransformExpressionTree> groupByExpressions = new HashSet<>();
-        for (String expression : _groupBy.getExpressions()) {
-          groupByExpressions.add(TransformExpressionTree.compileToExpressionTree(expression));
+        int numAggregationFunctions = _aggregationFunctions.length;
+        AggregationFunctionColumnPair[] aggregationFunctionColumnPairs =
+            new AggregationFunctionColumnPair[numAggregationFunctions];
+        boolean hasUnsupportedAggregationFunction = false;
+        for (int i = 0; i < numAggregationFunctions; i++) {
+          AggregationFunctionColumnPair aggregationFunctionColumnPair =
+              AggregationFunctionUtils.getAggregationFunctionColumnPair(_aggregationFunctions[i]);
+          if (aggregationFunctionColumnPair != null) {
+            aggregationFunctionColumnPairs[i] = aggregationFunctionColumnPair;
+          } else {
+            hasUnsupportedAggregationFunction = true;
+            break;
+          }
         }
-        FilterQueryTree rootFilterNode = RequestUtils.generateFilterQueryTree(brokerRequest);
-        for (StarTreeV2 starTreeV2 : starTrees) {
-          if (StarTreeUtils
-              .isFitForStarTree(starTreeV2.getMetadata(), aggregationFunctionColumnPairs, groupByExpressions,
-                  rootFilterNode)) {
-            _transformPlanNode = null;
-            _starTreeTransformPlanNode =
-                new StarTreeTransformPlanNode(starTreeV2, aggregationFunctionColumnPairs, groupByExpressions,
-                    rootFilterNode, brokerRequest.getDebugOptions());
-            return;
+        if (!hasUnsupportedAggregationFunction) {
+          FilterQueryTree rootFilterNode = RequestUtils.generateFilterQueryTree(brokerRequest);
+          for (StarTreeV2 starTreeV2 : starTrees) {
+            if (StarTreeUtils
+                .isFitForStarTree(starTreeV2.getMetadata(), aggregationFunctionColumnPairs, _groupByExpressions,
+                    rootFilterNode)) {
+              _transformPlanNode = null;
+              _starTreeTransformPlanNode =
+                  new StarTreeTransformPlanNode(starTreeV2, aggregationFunctionColumnPairs, _groupByExpressions,
+                      rootFilterNode, brokerRequest.getDebugOptions());
+              return;
+            }
           }
         }
       }
     }
 
     Set<TransformExpressionTree> expressionsToTransform =
-        AggregationFunctionUtils.collectExpressionsToTransform(brokerRequest, _functionContexts);
+        AggregationFunctionUtils.collectExpressionsToTransform(_aggregationFunctions, _groupByExpressions);
     _transformPlanNode = new TransformPlanNode(_indexSegment, brokerRequest, expressionsToTransform);
     _starTreeTransformPlanNode = null;
   }
@@ -101,11 +115,11 @@ public class AggregationGroupByPlanNode implements PlanNode {
     int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
     if (_transformPlanNode != null) {
       // Do not use star-tree
-      return new AggregationGroupByOperator(_functionContexts, _groupBy, _maxInitialResultHolderCapacity,
+      return new AggregationGroupByOperator(_aggregationFunctions, _groupByExpressions, _maxInitialResultHolderCapacity,
           _numGroupsLimit, _transformPlanNode.run(), numTotalDocs, false);
     } else {
       // Use star-tree
-      return new AggregationGroupByOperator(_functionContexts, _groupBy, _maxInitialResultHolderCapacity,
+      return new AggregationGroupByOperator(_aggregationFunctions, _groupByExpressions, _maxInitialResultHolderCapacity,
           _numGroupsLimit, _starTreeTransformPlanNode.run(), numTotalDocs, true);
     }
   }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
index 8932a37..ce0217b 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
@@ -18,7 +18,6 @@
  */
 package org.apache.pinot.core.plan;
 
-import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 import org.apache.pinot.common.request.AggregationInfo;
@@ -28,7 +27,7 @@ import org.apache.pinot.common.utils.request.FilterQueryTree;
 import org.apache.pinot.common.utils.request.RequestUtils;
 import org.apache.pinot.core.indexsegment.IndexSegment;
 import org.apache.pinot.core.operator.query.AggregationOperator;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 import org.apache.pinot.core.startree.StarTreeUtils;
 import org.apache.pinot.core.startree.plan.StarTreeTransformPlanNode;
@@ -47,38 +46,50 @@ public class AggregationPlanNode implements PlanNode {
 
   private final IndexSegment _indexSegment;
   private final List<AggregationInfo> _aggregationInfos;
-  private final AggregationFunctionContext[] _functionContexts;
+  private final AggregationFunction[] _aggregationFunctions;
   private final TransformPlanNode _transformPlanNode;
   private final StarTreeTransformPlanNode _starTreeTransformPlanNode;
 
   public AggregationPlanNode(IndexSegment indexSegment, BrokerRequest brokerRequest) {
     _indexSegment = indexSegment;
     _aggregationInfos = brokerRequest.getAggregationsInfo();
-    _functionContexts = AggregationFunctionUtils.getAggregationFunctionContexts(brokerRequest);
+    _aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(brokerRequest);
 
     List<StarTreeV2> starTrees = indexSegment.getStarTrees();
     if (starTrees != null) {
       if (!StarTreeUtils.isStarTreeDisabled(brokerRequest)) {
-        Set<AggregationFunctionColumnPair> aggregationFunctionColumnPairs = new HashSet<>();
-        for (AggregationInfo aggregationInfo : _aggregationInfos) {
-          aggregationFunctionColumnPairs.add(AggregationFunctionUtils.getFunctionColumnPair(aggregationInfo));
+        int numAggregationFunctions = _aggregationFunctions.length;
+        AggregationFunctionColumnPair[] aggregationFunctionColumnPairs =
+            new AggregationFunctionColumnPair[numAggregationFunctions];
+        boolean hasUnsupportedAggregationFunction = false;
+        for (int i = 0; i < numAggregationFunctions; i++) {
+          AggregationFunctionColumnPair aggregationFunctionColumnPair =
+              AggregationFunctionUtils.getAggregationFunctionColumnPair(_aggregationFunctions[i]);
+          if (aggregationFunctionColumnPair != null) {
+            aggregationFunctionColumnPairs[i] = aggregationFunctionColumnPair;
+          } else {
+            hasUnsupportedAggregationFunction = true;
+            break;
+          }
         }
-        FilterQueryTree rootFilterNode = RequestUtils.generateFilterQueryTree(brokerRequest);
-        for (StarTreeV2 starTreeV2 : starTrees) {
-          if (StarTreeUtils
-              .isFitForStarTree(starTreeV2.getMetadata(), aggregationFunctionColumnPairs, null, rootFilterNode)) {
-            _transformPlanNode = null;
-            _starTreeTransformPlanNode =
-                new StarTreeTransformPlanNode(starTreeV2, aggregationFunctionColumnPairs, null, rootFilterNode,
-                    brokerRequest.getDebugOptions());
-            return;
+        if (!hasUnsupportedAggregationFunction) {
+          FilterQueryTree rootFilterNode = RequestUtils.generateFilterQueryTree(brokerRequest);
+          for (StarTreeV2 starTreeV2 : starTrees) {
+            if (StarTreeUtils
+                .isFitForStarTree(starTreeV2.getMetadata(), aggregationFunctionColumnPairs, null, rootFilterNode)) {
+              _transformPlanNode = null;
+              _starTreeTransformPlanNode =
+                  new StarTreeTransformPlanNode(starTreeV2, aggregationFunctionColumnPairs, null, rootFilterNode,
+                      brokerRequest.getDebugOptions());
+              return;
+            }
           }
         }
       }
     }
 
     Set<TransformExpressionTree> expressionsToTransform =
-        AggregationFunctionUtils.collectExpressionsToTransform(brokerRequest, _functionContexts);
+        AggregationFunctionUtils.collectExpressionsToTransform(_aggregationFunctions, null);
     _transformPlanNode = new TransformPlanNode(_indexSegment, brokerRequest, expressionsToTransform);
     _starTreeTransformPlanNode = null;
   }
@@ -88,10 +99,10 @@ public class AggregationPlanNode implements PlanNode {
     int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
     if (_transformPlanNode != null) {
       // Do not use star-tree
-      return new AggregationOperator(_functionContexts, _transformPlanNode.run(), numTotalDocs, false);
+      return new AggregationOperator(_aggregationFunctions, _transformPlanNode.run(), numTotalDocs, false);
     } else {
       // Use star-tree
-      return new AggregationOperator(_functionContexts, _starTreeTransformPlanNode.run(), numTotalDocs, true);
+      return new AggregationOperator(_aggregationFunctions, _starTreeTransformPlanNode.run(), numTotalDocs, true);
     }
   }
 
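In the non-star-tree path, the set of expressions to project is now built from the aggregation functions and group-by expressions that were already constructed at planning time. AggregationPlanNode passes null because it has no group-by. Below is a sketch of that union with a null-tolerant group-by argument. It uses plain strings instead of TransformExpressionTree, and its collect method is hypothetical; it does not follow the signature of AggregationFunctionUtils.collectExpressionsToTransform.

```java
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

final class ExpressionCollectorSketch {
  // Union of every aggregation function's input expressions and the group-by expressions.
  // groupByExpressions may be null for aggregation-only queries.
  // Insertion order is kept only to make the result easy to inspect.
  static Set<String> collect(List<List<String>> functionInputs, String[] groupByExpressions) {
    Set<String> expressions = new LinkedHashSet<>();
    for (List<String> inputs : functionInputs) {
      expressions.addAll(inputs);
    }
    if (groupByExpressions != null) {
      for (String expression : groupByExpressions) {
        expressions.add(expression);
      }
    }
    return expressions;
  }
}
```

Because the result is a set, an expression that several functions share, or that a function shares with the group-by, is projected only once.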
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/DictionaryBasedAggregationPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/DictionaryBasedAggregationPlanNode.java
index 4bf368d..561f3eb 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/DictionaryBasedAggregationPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/DictionaryBasedAggregationPlanNode.java
@@ -21,10 +21,11 @@ package org.apache.pinot.core.plan;
 import java.util.HashMap;
 import java.util.Map;
 import org.apache.pinot.common.request.BrokerRequest;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.core.common.Operator;
 import org.apache.pinot.core.indexsegment.IndexSegment;
 import org.apache.pinot.core.operator.query.DictionaryBasedAggregationOperator;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 import org.apache.pinot.core.segment.index.readers.Dictionary;
 import org.slf4j.Logger;
@@ -37,9 +38,9 @@ import org.slf4j.LoggerFactory;
 public class DictionaryBasedAggregationPlanNode implements PlanNode {
   private static final Logger LOGGER = LoggerFactory.getLogger(DictionaryBasedAggregationPlanNode.class);
 
+  private final IndexSegment _indexSegment;
+  private final AggregationFunction[] _aggregationFunctions;
   private final Map<String, Dictionary> _dictionaryMap;
-  private final AggregationFunctionContext[] _aggregationFunctionContexts;
-  private IndexSegment _indexSegment;
 
   /**
    * Constructor for the class.
@@ -49,22 +50,18 @@ public class DictionaryBasedAggregationPlanNode implements PlanNode {
    */
   public DictionaryBasedAggregationPlanNode(IndexSegment indexSegment, BrokerRequest brokerRequest) {
     _indexSegment = indexSegment;
+    _aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(brokerRequest);
     _dictionaryMap = new HashMap<>();
-
-    _aggregationFunctionContexts = AggregationFunctionUtils.getAggregationFunctionContexts(brokerRequest);
-
-    for (AggregationFunctionContext aggregationFunctionContext : _aggregationFunctionContexts) {
-      String column = aggregationFunctionContext.getColumnName();
-      if (!_dictionaryMap.containsKey(column)) {
-        _dictionaryMap.put(column, _indexSegment.getDataSource(column).getDictionary());
-      }
+    for (AggregationFunction aggregationFunction : _aggregationFunctions) {
+      String column = ((TransformExpressionTree) aggregationFunction.getInputExpressions().get(0)).getValue();
+      _dictionaryMap.computeIfAbsent(column, k -> _indexSegment.getDataSource(k).getDictionary());
     }
   }
 
   @Override
   public Operator run() {
-    return new DictionaryBasedAggregationOperator(_aggregationFunctionContexts,
-        _indexSegment.getSegmentMetadata().getTotalDocs(), _dictionaryMap);
+    return new DictionaryBasedAggregationOperator(_aggregationFunctions, _dictionaryMap,
+        _indexSegment.getSegmentMetadata().getTotalDocs());
   }
 
   @Override
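DictionaryBasedAggregationPlanNode now resolves each column's dictionary once, at construction time, keyed by the column name of the function's first input expression. The sketch below illustrates the computeIfAbsent de-duplication that replaces the old containsKey()/put() pair. FakeSegment and DictionaryMapSketch are hypothetical stand-ins, not Pinot classes, and plain strings take the place of real Dictionary objects.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for IndexSegment: counts dictionary lookups so that the
// de-duplication done by computeIfAbsent is observable.
class FakeSegment {
  int dictionaryLookups = 0;

  String getDictionary(String column) {
    dictionaryLookups++;
    return "dictionary(" + column + ")";
  }
}

class DictionaryMapSketch {
  // One entry per distinct input column: computeIfAbsent only invokes the loader
  // when the column is not yet present in the map.
  static Map<String, String> buildDictionaryMap(FakeSegment segment, List<String> inputColumns) {
    Map<String, String> dictionaryMap = new HashMap<>();
    for (String column : inputColumns) {
      dictionaryMap.computeIfAbsent(column, segment::getDictionary);
    }
    return dictionaryMap;
  }

  public static void main(String[] args) {
    // e.g. SELECT MIN(foo), MAX(foo), MAX(bar): three functions, two distinct columns
    FakeSegment segment = new FakeSegment();
    Map<String, String> map = buildDictionaryMap(segment, Arrays.asList("foo", "foo", "bar"));
    System.out.println(map.size() + " entries, " + segment.dictionaryLookups + " lookups");
    // prints "2 entries, 2 lookups"
  }
}
```

With MIN(foo), MAX(foo) and MAX(bar), the loader runs only twice, once per distinct column.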
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/MetadataBasedAggregationPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/MetadataBasedAggregationPlanNode.java
index 3ac5e5d..8125088 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/MetadataBasedAggregationPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/MetadataBasedAggregationPlanNode.java
@@ -24,13 +24,13 @@ import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
 import org.apache.pinot.common.request.AggregationInfo;
 import org.apache.pinot.common.request.BrokerRequest;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.core.common.DataSource;
 import org.apache.pinot.core.common.Operator;
 import org.apache.pinot.core.indexsegment.IndexSegment;
 import org.apache.pinot.core.operator.query.MetadataBasedAggregationOperator;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
-import org.apache.pinot.core.segment.index.metadata.SegmentMetadata;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -43,7 +43,8 @@ public class MetadataBasedAggregationPlanNode implements PlanNode {
 
   private final IndexSegment _indexSegment;
   private final List<AggregationInfo> _aggregationInfos;
-  private final BrokerRequest _brokerRequest;
+  private final AggregationFunction[] _aggregationFunctions;
+  private final Map<String, DataSource> _dataSourceMap;
 
   /**
    * Constructor for the class.
@@ -53,27 +54,21 @@ public class MetadataBasedAggregationPlanNode implements PlanNode {
    */
   public MetadataBasedAggregationPlanNode(IndexSegment indexSegment, BrokerRequest brokerRequest) {
     _indexSegment = indexSegment;
-    _brokerRequest = brokerRequest;
     _aggregationInfos = brokerRequest.getAggregationsInfo();
+    _aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(brokerRequest);
+    _dataSourceMap = new HashMap<>();
+    for (AggregationFunction aggregationFunction : _aggregationFunctions) {
+      if (aggregationFunction.getType() != AggregationFunctionType.COUNT) {
+        String column = ((TransformExpressionTree) aggregationFunction.getInputExpressions().get(0)).getValue();
+        _dataSourceMap.computeIfAbsent(column, _indexSegment::getDataSource);
+      }
+    }
   }
 
   @Override
   public Operator run() {
-    SegmentMetadata segmentMetadata = _indexSegment.getSegmentMetadata();
-    AggregationFunctionContext[] aggregationFunctionContexts =
-        AggregationFunctionUtils.getAggregationFunctionContexts(_brokerRequest);
-
-    Map<String, DataSource> dataSourceMap = new HashMap<>();
-    for (AggregationFunctionContext aggregationFunctionContext : aggregationFunctionContexts) {
-      if (aggregationFunctionContext.getAggregationFunction().getType() != AggregationFunctionType.COUNT) {
-        String column = aggregationFunctionContext.getColumnName();
-        if (!dataSourceMap.containsKey(column)) {
-          dataSourceMap.put(column, _indexSegment.getDataSource(column));
-        }
-      }
-    }
-
-    return new MetadataBasedAggregationOperator(aggregationFunctionContexts, segmentMetadata, dataSourceMap);
+    return new MetadataBasedAggregationOperator(_aggregationFunctions, _indexSegment.getSegmentMetadata(),
+        _dataSourceMap);
   }
 
   @Override
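MetadataBasedAggregationPlanNode now builds its data source map in the constructor instead of in run(), skipping COUNT because COUNT(*) has no input column. The following sketch shows that planning-time step under simplified, hypothetical types (SketchFunctionType, SketchAggregation, MetadataPlanSketch); strings stand in for DataSource objects, and the remark that COUNT can be answered from the total document count is an assumption based on the segment metadata being passed to the operator.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, simplified function types; not Pinot's AggregationFunctionType.
enum SketchFunctionType { COUNT, MIN, MAX }

// Hypothetical aggregation: a type plus its single input column (null for COUNT(*)).
final class SketchAggregation {
  final SketchFunctionType type;
  final String column;

  SketchAggregation(SketchFunctionType type, String column) {
    this.type = type;
    this.column = column;
  }
}

final class MetadataPlanSketch {
  final Map<String, String> dataSourceMap = new HashMap<>();

  // Resolves data sources once, while planning. COUNT(*) is skipped because it has
  // no input column (presumably it is served from the segment's total doc count).
  MetadataPlanSketch(List<SketchAggregation> aggregations) {
    for (SketchAggregation aggregation : aggregations) {
      if (aggregation.type != SketchFunctionType.COUNT) {
        dataSourceMap.computeIfAbsent(aggregation.column, c -> "dataSource(" + c + ")");
      }
    }
  }

  public static void main(String[] args) {
    MetadataPlanSketch plan = new MetadataPlanSketch(Arrays.asList(
        new SketchAggregation(SketchFunctionType.COUNT, null),
        new SketchAggregation(SketchFunctionType.MIN, "foo"),
        new SketchAggregation(SketchFunctionType.MAX, "foo")));
    System.out.println(plan.dataSourceMap.keySet()); // prints [foo]
  }
}
```

Doing this work in the constructor keeps run() down to a single operator construction.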
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/maker/InstancePlanMakerImplV2.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/maker/InstancePlanMakerImplV2.java
index 182f79b..2103115 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/maker/InstancePlanMakerImplV2.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/maker/InstancePlanMakerImplV2.java
@@ -26,7 +26,6 @@ import java.util.concurrent.ExecutorService;
 import org.apache.pinot.common.function.AggregationFunctionType;
 import org.apache.pinot.common.request.AggregationInfo;
 import org.apache.pinot.common.request.BrokerRequest;
-import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.core.data.manager.SegmentDataManager;
 import org.apache.pinot.core.indexsegment.IndexSegment;
 import org.apache.pinot.core.plan.AggregationGroupByOrderByPlanNode;
@@ -205,9 +204,9 @@ public class InstancePlanMakerImplV2 implements PlanMaker {
         AggregationFunctionType.getAggregationFunctionType(aggregationInfo.getAggregationType());
     if (functionType
         .isOfType(AggregationFunctionType.MIN, AggregationFunctionType.MAX, AggregationFunctionType.MINMAXRANGE)) {
-      String expression = AggregationFunctionUtils.getAggregationExpressions(aggregationInfo).get(0);
-      if (TransformExpressionTree.compileToExpressionTree(expression).isColumn()) {
-        Dictionary dictionary = indexSegment.getDataSource(expression).getDictionary();
+      String column = AggregationFunctionUtils.getArguments(aggregationInfo).get(0);
+      if (indexSegment.getColumnNames().contains(column)) {
+        Dictionary dictionary = indexSegment.getDataSource(column).getDictionary();
         return dictionary != null && dictionary.isSorted();
       }
     }
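The dictionary-based eligibility check now tests whether the first argument is one of the segment's column names, instead of compiling it into a TransformExpressionTree and calling isColumn(). Below is a hedged sketch of that decision. SketchDictionary and DictionaryPlanCheck are hypothetical, and the min()/max() helpers assume, as the isSorted() check suggests, that a sorted dictionary exposes the extremes as its first and last entries.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical minimal dictionary; when sorted, values are in ascending order.
final class SketchDictionary {
  final List<Integer> values;
  final boolean sorted;

  SketchDictionary(List<Integer> values, boolean sorted) {
    this.values = values;
    this.sorted = sorted;
  }
}

final class DictionaryPlanCheck {
  private static final Set<String> SUPPORTED =
      new HashSet<>(Arrays.asList("MIN", "MAX", "MINMAXRANGE"));

  // The function must be MIN/MAX/MINMAXRANGE, the argument must name a physical column
  // (a transform such as "add(foo,1)" is not in the column set), and that column must
  // have a sorted dictionary. In this sketch everything else returns false.
  static boolean canUseDictionary(String functionType, String argument, Set<String> columnNames,
      Map<String, SketchDictionary> dictionaries) {
    if (!SUPPORTED.contains(functionType) || !columnNames.contains(argument)) {
      return false;
    }
    SketchDictionary dictionary = dictionaries.get(argument);
    return dictionary != null && dictionary.sorted;
  }

  // With a sorted dictionary, MIN and MAX are the first and last entries.
  static int min(SketchDictionary dictionary) {
    return dictionary.values.get(0);
  }

  static int max(SketchDictionary dictionary) {
    return dictionary.values.get(dictionary.values.size() - 1);
  }

  public static void main(String[] args) {
    Set<String> columns = new HashSet<>(Arrays.asList("foo"));
    Map<String, SketchDictionary> dictionaries = new HashMap<>();
    SketchDictionary foo = new SketchDictionary(Arrays.asList(1, 5, 9), true);
    dictionaries.put("foo", foo);
    System.out.println(canUseDictionary("MAX", "foo", columns, dictionaries) + " " + max(foo));
    // prints "true 9"
  }
}
```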
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/AggregationFunctionContext.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/AggregationFunctionContext.java
deleted file mode 100644
index 9d1027f..0000000
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/AggregationFunctionContext.java
+++ /dev/null
@@ -1,85 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.pinot.core.query.aggregation;
-
-import com.google.common.base.Preconditions;
-import java.util.List;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
-
-
-/**
- * This class caches miscellaneous data to perform efficient aggregation.
- *
- * TODO: Remove this class, as it no longer provides any value after aggregation functions now store
- * their arguments.
- */
-public class AggregationFunctionContext {
-  private final AggregationFunction _aggregationFunction;
-  private final List<String> _expressions;
-  private final String columnName;
-
-  public AggregationFunctionContext(AggregationFunction aggregationFunction, List<String> expressions) {
-    Preconditions.checkArgument(expressions.size() >= 1, "Aggregation functions require at least one argument.");
-    _aggregationFunction = aggregationFunction;
-    _expressions = expressions;
-    columnName = AggregationFunctionUtils.concatArgs(expressions);
-  }
-
-  /**
-   * Returns the aggregation function.
-   */
-  public AggregationFunction getAggregationFunction() {
-    return _aggregationFunction;
-  }
-
-  /**
-   * Returns the arguments for the aggregation function.
-   *
-   * @return List of Strings containing the arguments for the aggregation function.
-   */
-  public List<String> getExpressions() {
-    return _expressions;
-  }
-
-  /**
-   * Returns the column for aggregation function.
-   *
-   * @return Aggregation Column (could be column name or UDF expression).
-   */
-  public String getColumnName() {
-    return columnName;
-  }
-
-  /**
-   * Returns the aggregation column name for the results.
-   * <p>E.g. AVG(foo) -> avg_foo
-   */
-  public String getAggregationColumnName() {
-    return _aggregationFunction.getColumnName();
-  }
-
-  /**
-   * Returns the aggregation column name for the result table.
-   * <p>E.g. AVGMV(foo) -> avgMV(foo)
-   */
-  public String getResultColumnName() {
-    return _aggregationFunction.getResultColumnName();
-  }
-}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DefaultAggregationExecutor.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DefaultAggregationExecutor.java
index 4ab0754..706bf1d 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DefaultAggregationExecutor.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DefaultAggregationExecutor.java
@@ -19,70 +19,42 @@
 package org.apache.pinot.core.query.aggregation;
 
 import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
-import org.apache.pinot.common.function.AggregationFunctionType;
-import org.apache.pinot.common.request.transform.TransformExpressionTree;
-import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.operator.blocks.TransformBlock;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 
 
 public class DefaultAggregationExecutor implements AggregationExecutor {
-  protected final int _numFunctions;
-  protected final AggregationFunction[] _functions;
-  protected final AggregationResultHolder[] _resultHolders;
-  protected final TransformExpressionTree[][] _expressions;
-
-  public DefaultAggregationExecutor(AggregationFunctionContext[] functionContexts) {
-    _numFunctions = functionContexts.length;
-    _functions = new AggregationFunction[_numFunctions];
-    _resultHolders = new AggregationResultHolder[_numFunctions];
-
-    _expressions = new TransformExpressionTree[_numFunctions][];
-    for (int i = 0; i < _numFunctions; i++) {
-      AggregationFunction function = functionContexts[i].getAggregationFunction();
-      _functions[i] = function;
-      _resultHolders[i] = _functions[i].createAggregationResultHolder();
-
-      if (function.getType() != AggregationFunctionType.COUNT) {
-        // count(*) does not have a column so handle rest of the aggregate
-        // functions -- sum, min, max etc
-
-        List<TransformExpressionTree> inputExpressionsList = function.getInputExpressions();
-        _expressions[i] = inputExpressionsList.toArray(new TransformExpressionTree[0]);
-      }
+  protected final AggregationFunction[] _aggregationFunctions;
+  protected final AggregationResultHolder[] _aggregationResultHolders;
+
+  public DefaultAggregationExecutor(AggregationFunction[] aggregationFunctions) {
+    _aggregationFunctions = aggregationFunctions;
+    int numAggregationFunctions = aggregationFunctions.length;
+    _aggregationResultHolders = new AggregationResultHolder[numAggregationFunctions];
+    for (int i = 0; i < numAggregationFunctions; i++) {
+      _aggregationResultHolders[i] = aggregationFunctions[i].createAggregationResultHolder();
     }
   }
 
   @Override
   public void aggregate(TransformBlock transformBlock) {
+    int numAggregationFunctions = _aggregationFunctions.length;
     int length = transformBlock.getNumDocs();
-    for (int i = 0; i < _numFunctions; i++) {
-      AggregationFunction function = _functions[i];
-      AggregationResultHolder resultHolder = _resultHolders[i];
-      if (function.getType() == AggregationFunctionType.COUNT) {
-        // handle count(*) function
-        function.aggregate(length, resultHolder, Collections.emptyMap());
-      } else {
-        // handle rest of the aggregate functions -- sum, min, max etc
-        Map<String, BlockValSet> blockValSetMap = new HashMap<>();
-
-        for (int j = 0; j < _expressions[i].length; j++) {
-          blockValSetMap.put(_expressions[i][j].toString(), transformBlock.getBlockValueSet(_expressions[i][j]));
-        }
-        function.aggregate(length, resultHolder, blockValSetMap);
-      }
+    for (int i = 0; i < numAggregationFunctions; i++) {
+      AggregationFunction aggregationFunction = _aggregationFunctions[i];
+      aggregationFunction.aggregate(length, _aggregationResultHolders[i],
+          AggregationFunctionUtils.getBlockValSetMap(aggregationFunction, transformBlock));
     }
   }
 
   @Override
   public List<Object> getResult() {
-    List<Object> aggregationResults = new ArrayList<>(_numFunctions);
-    for (int i = 0; i < _numFunctions; i++) {
-      aggregationResults.add(_functions[i].extractAggregationResult(_resultHolders[i]));
+    int numFunctions = _aggregationFunctions.length;
+    List<Object> aggregationResults = new ArrayList<>(numFunctions);
+    for (int i = 0; i < numFunctions; i++) {
+      aggregationResults.add(_aggregationFunctions[i].extractAggregationResult(_aggregationResultHolders[i]));
     }
     return aggregationResults;
   }
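DefaultAggregationExecutor now receives pre-built AggregationFunctions, creates one result holder per function up front, and delegates building each function's block value map to AggregationFunctionUtils.getBlockValSetMap(). The sketch below mimics that flow with a hypothetical SketchAggregationFunction contract in which a block is a map from column name to long[] values and each result holder is a long[1] accumulator; it is not Pinot's API.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, heavily simplified aggregation contract.
interface SketchAggregationFunction {
  List<String> inputColumns();
  long[] createResultHolder();
  void aggregate(int length, long[] holder, Map<String, long[]> blockValSetMap);
  long extractResult(long[] holder);
}

final class SumFn implements SketchAggregationFunction {
  private final String column;
  SumFn(String column) { this.column = column; }
  public List<String> inputColumns() { return Collections.singletonList(column); }
  public long[] createResultHolder() { return new long[1]; }
  public void aggregate(int length, long[] holder, Map<String, long[]> blockValSetMap) {
    long[] values = blockValSetMap.get(column);
    for (int i = 0; i < length; i++) {
      holder[0] += values[i];
    }
  }
  public long extractResult(long[] holder) { return holder[0]; }
}

// COUNT(*) declares no inputs and only needs the number of docs in the block.
final class CountFn implements SketchAggregationFunction {
  public List<String> inputColumns() { return Collections.emptyList(); }
  public long[] createResultHolder() { return new long[1]; }
  public void aggregate(int length, long[] holder, Map<String, long[]> blockValSetMap) {
    holder[0] += length;
  }
  public long extractResult(long[] holder) { return holder[0]; }
}

final class ExecutorSketch {
  private final SketchAggregationFunction[] functions;
  private final long[][] holders;

  ExecutorSketch(SketchAggregationFunction[] functions) {
    this.functions = functions;
    holders = new long[functions.length][];
    for (int i = 0; i < functions.length; i++) {
      holders[i] = functions[i].createResultHolder();
    }
  }

  // Per block: hand each function only the columns it declares as inputs.
  void aggregate(int numDocs, Map<String, long[]> block) {
    for (int i = 0; i < functions.length; i++) {
      Map<String, long[]> blockValSetMap = new HashMap<>();
      for (String column : functions[i].inputColumns()) {
        blockValSetMap.put(column, block.get(column));
      }
      functions[i].aggregate(numDocs, holders[i], blockValSetMap);
    }
  }

  List<Long> getResult() {
    List<Long> results = new ArrayList<>(functions.length);
    for (int i = 0; i < functions.length; i++) {
      results.add(functions[i].extractResult(holders[i]));
    }
    return results;
  }

  public static void main(String[] args) {
    ExecutorSketch executor = new ExecutorSketch(new SketchAggregationFunction[] {new CountFn(), new SumFn("foo")});
    Map<String, long[]> block = new HashMap<>();
    block.put("foo", new long[] {1, 2, 3});
    executor.aggregate(3, block);
    System.out.println(executor.getResult()); // prints [3, 6]
  }
}
```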
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DistinctTable.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DistinctTable.java
index de51993..d294218 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DistinctTable.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DistinctTable.java
@@ -21,7 +21,6 @@ package org.apache.pinot.core.query.aggregation;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
@@ -36,6 +35,7 @@ import org.apache.pinot.core.common.datatable.DataTableFactory;
 import org.apache.pinot.core.data.table.BaseTable;
 import org.apache.pinot.core.data.table.Key;
 import org.apache.pinot.core.data.table.Record;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.spi.utils.ByteArray;
 
 
@@ -59,7 +59,7 @@ public class DistinctTable extends BaseTable {
     // NOTE: The passed in capacity is calculated based on the LIMIT in the query as Math.max(limit * 5, 5000). When
     //       LIMIT is smaller than (64 * 1024 * 0.75 (load factor) / 5 = 9830), then it is guaranteed that no resize is
     //       required.
-    super(dataSchema, Collections.emptyList(), orderBy, capacity);
+    super(dataSchema, new AggregationFunction[0], orderBy, capacity);
     int initialCapacity = Math.min(MAX_INITIAL_CAPACITY, HashUtil.getHashMapCapacity(capacity));
     _uniqueRecordsSet = new HashSet<>(initialCapacity);
     _noMoreNewRecords = false;
@@ -162,8 +162,8 @@ public class DistinctTable extends BaseTable {
     // information to pass to super class so just pass null, empty lists
     // and the broker will set the correct information before merging the
     // data tables.
-    super(new DataSchema(new String[0], new DataSchema.ColumnDataType[0]), Collections.emptyList(), new ArrayList<>(),
-        0);
+    super(new DataSchema(new String[0], new DataSchema.ColumnDataType[0]), new AggregationFunction[0],
+        new ArrayList<>(), 0);
     DataTable dataTable = DataTableFactory.getDataTable(byteBuffer);
     _dataSchema = dataTable.getDataSchema();
     _uniqueRecordsSet = new HashSet<>();
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java
index 08e4868..362309c 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java
@@ -76,21 +76,22 @@ public interface AggregationFunction<IntermediateResult, FinalResult extends Com
   /**
    * Performs aggregation on the given block value sets (aggregation only).
    */
-  void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap);
+  void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap);
 
   /**
    * Performs aggregation on the given group key array and block value sets (aggregation group-by on single-value
    * columns).
    */
   void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSets);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap);
 
   /**
    * Performs aggregation on the given group keys array and block value sets (aggregation group-by on multi-value
    * columns).
    */
   void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSets);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap);
 
   /**
    * Extracts the intermediate result from the aggregation result holder (aggregation only).
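The aggregate methods above are now keyed by TransformExpressionTree rather than by the expression's string form, which relies on the key type having structural equals() and hashCode(). ExprKey below is a hypothetical stand-in showing how such a key can work: two independently built trees for add(foo, 1) are equal and hash alike, so a map lookup succeeds without any string conversion. Caching the hash in the constructor is a design choice of this sketch, not a claim about how TransformExpressionTree is implemented.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

// Hypothetical immutable expression tree usable as a HashMap key.
final class ExprKey {
  enum Type { COLUMN, FUNCTION, LITERAL }

  final Type type;
  final String value;
  final List<ExprKey> children;
  private final int hash; // cached: keys are immutable and hashed on every lookup

  ExprKey(Type type, String value, List<ExprKey> children) {
    this.type = type;
    this.value = value;
    this.children = children;
    this.hash = Objects.hash(type, value, children);
  }

  static ExprKey column(String name) {
    return new ExprKey(Type.COLUMN, name, Collections.<ExprKey>emptyList());
  }

  static ExprKey literal(String literal) {
    return new ExprKey(Type.LITERAL, literal, Collections.<ExprKey>emptyList());
  }

  static ExprKey function(String name, ExprKey... arguments) {
    return new ExprKey(Type.FUNCTION, name, Arrays.asList(arguments));
  }

  // Structural equality: same node type, same value, and equal children in order.
  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof ExprKey)) {
      return false;
    }
    ExprKey other = (ExprKey) o;
    return hash == other.hash && type == other.type && value.equals(other.value)
        && children.equals(other.children);
  }

  @Override
  public int hashCode() {
    return hash;
  }

  public static void main(String[] args) {
    Map<ExprKey, String> blockValSetMap = new HashMap<>();
    blockValSetMap.put(function("add", column("foo"), literal("1")), "values of add(foo,1)");
    // A separately built, structurally equal key finds the same entry.
    System.out.println(blockValSetMap.get(function("add", column("foo"), literal("1"))));
  }
}
```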
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index 90f99f6..f21f7fb 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -19,7 +19,6 @@
 package org.apache.pinot.core.query.aggregation.function;
 
 import com.google.common.base.Preconditions;
-import java.util.ArrayList;
 import java.util.List;
 import javax.annotation.Nullable;
 import org.apache.pinot.common.function.AggregationFunctionType;
@@ -45,45 +44,79 @@ public class AggregationFunctionFactory {
   public static AggregationFunction getAggregationFunction(AggregationInfo aggregationInfo,
       @Nullable BrokerRequest brokerRequest) {
     String functionName = aggregationInfo.getAggregationType();
-    List<String> expressions = AggregationFunctionUtils.getAggregationExpressions(aggregationInfo);
+    List<String> arguments = AggregationFunctionUtils.getArguments(aggregationInfo);
 
     try {
       String upperCaseFunctionName = functionName.toUpperCase();
+      String column = arguments.get(0);
       if (upperCaseFunctionName.startsWith("PERCENTILE")) {
         String remainingFunctionName = upperCaseFunctionName.substring(10);
-        List<String> args = new ArrayList<>(expressions);
-        if (remainingFunctionName.matches("\\d+")) {
-          // Percentile
-          args.add(remainingFunctionName);
-          return new PercentileAggregationFunction(args);
-        } else if (remainingFunctionName.matches("EST\\d+")) {
-          // PercentileEst
-          args.add(remainingFunctionName.substring(3));
-          return new PercentileEstAggregationFunction(args);
-        } else if (remainingFunctionName.matches("TDIGEST\\d+")) {
-          // PercentileTDigest
-          args.add(remainingFunctionName.substring(7));
-          return new PercentileTDigestAggregationFunction(args);
-        } else if (remainingFunctionName.matches("\\d+MV")) {
-          // PercentileMV
-          args.add(remainingFunctionName.substring(0, remainingFunctionName.length() - 2));
-          return new PercentileMVAggregationFunction(args);
-        } else if (remainingFunctionName.matches("EST\\d+MV")) {
-          // PercentileEstMV
-          args.add(remainingFunctionName.substring(3, remainingFunctionName.length() - 2));
-          return new PercentileEstMVAggregationFunction(args);
-        } else if (remainingFunctionName.matches("TDIGEST\\d+MV")) {
-          // PercentileTDigestMV
-          args.add(remainingFunctionName.substring(7, remainingFunctionName.length() - 2));
-          return new PercentileTDigestMVAggregationFunction(args);
-        } else {
-          throw new IllegalArgumentException();
+        int numArguments = arguments.size();
+        if (numArguments == 1) {
+          // Single argument percentile (e.g. Percentile99(foo), PercentileTDigest95(bar), etc.)
+          if (remainingFunctionName.matches("\\d+")) {
+            // Percentile
+            return new PercentileAggregationFunction(column,
+                AggregationFunctionUtils.parsePercentile(remainingFunctionName));
+          } else if (remainingFunctionName.matches("EST\\d+")) {
+            // PercentileEst
+            String percentileString = remainingFunctionName.substring(3);
+            return new PercentileEstAggregationFunction(column,
+                AggregationFunctionUtils.parsePercentile(percentileString));
+          } else if (remainingFunctionName.matches("TDIGEST\\d+")) {
+            // PercentileTDigest
+            String percentileString = remainingFunctionName.substring(7);
+            return new PercentileTDigestAggregationFunction(column,
+                AggregationFunctionUtils.parsePercentile(percentileString));
+          } else if (remainingFunctionName.matches("\\d+MV")) {
+            // PercentileMV
+            String percentileString = remainingFunctionName.substring(0, remainingFunctionName.length() - 2);
+            return new PercentileMVAggregationFunction(column,
+                AggregationFunctionUtils.parsePercentile(percentileString));
+          } else if (remainingFunctionName.matches("EST\\d+MV")) {
+            // PercentileEstMV
+            String percentileString = remainingFunctionName.substring(3, remainingFunctionName.length() - 2);
+            return new PercentileEstMVAggregationFunction(column,
+                AggregationFunctionUtils.parsePercentile(percentileString));
+          } else if (remainingFunctionName.matches("TDIGEST\\d+MV")) {
+            // PercentileTDigestMV
+            String percentileString = remainingFunctionName.substring(7, remainingFunctionName.length() - 2);
+            return new PercentileTDigestMVAggregationFunction(column,
+                AggregationFunctionUtils.parsePercentile(percentileString));
+          }
+        } else if (numArguments == 2) {
+          // Double arguments percentile (e.g. percentile(foo, 99), percentileTDigest(bar, 95), etc.)
+          int percentile = AggregationFunctionUtils.parsePercentile(arguments.get(1));
+          if (remainingFunctionName.isEmpty()) {
+            // Percentile
+            return new PercentileAggregationFunction(column, percentile);
+          }
+          if (remainingFunctionName.equals("EST")) {
+            // PercentileEst
+            return new PercentileEstAggregationFunction(column, percentile);
+          }
+          if (remainingFunctionName.equals("TDIGEST")) {
+            // PercentileTDigest
+            return new PercentileTDigestAggregationFunction(column, percentile);
+          }
+          if (remainingFunctionName.equals("MV")) {
+            // PercentileMV
+            return new PercentileMVAggregationFunction(column, percentile);
+          }
+          if (remainingFunctionName.equals("ESTMV")) {
+            // PercentileEstMV
+            return new PercentileEstMVAggregationFunction(column, percentile);
+          }
+          if (remainingFunctionName.equals("TDIGESTMV")) {
+            // PercentileTDigestMV
+            return new PercentileTDigestMVAggregationFunction(column, percentile);
+          }
         }
+        throw new IllegalArgumentException("Invalid percentile function");
       } else {
-        String column = expressions.get(0);
         switch (AggregationFunctionType.valueOf(upperCaseFunctionName)) {
           case COUNT:
-            return new CountAggregationFunction(column);
+            return new CountAggregationFunction();
           case MIN:
             return new MinAggregationFunction(column);
           case MAX:
@@ -103,7 +136,7 @@ public class AggregationFunctionFactory {
           case FASTHLL:
             return new FastHLLAggregationFunction(column);
           case DISTINCTCOUNTTHETASKETCH:
-            return new DistinctCountThetaSketchAggregationFunction(expressions);
+            return new DistinctCountThetaSketchAggregationFunction(arguments);
           case COUNTMV:
             return new CountMVAggregationFunction(column);
           case MINMV:
@@ -125,14 +158,13 @@ public class AggregationFunctionFactory {
           case DISTINCT:
             Preconditions.checkState(brokerRequest != null,
                 "Broker request must be provided for 'DISTINCT' aggregation function");
-            return new DistinctAggregationFunction(expressions, brokerRequest.getOrderBy(),
-                brokerRequest.getLimit());
+            return new DistinctAggregationFunction(arguments, brokerRequest.getOrderBy(), brokerRequest.getLimit());
           default:
             throw new IllegalArgumentException();
         }
       }
     } catch (Exception e) {
-      throw new BadQueryRequestException("Invalid aggregation function name: " + functionName);
+      throw new BadQueryRequestException("Invalid aggregation: " + aggregationInfo);
     }
   }
 }
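The factory now accepts a percentile either embedded in the function name (one argument, e.g. PERCENTILE99(foo)) or passed as a second argument (e.g. PERCENTILETDIGEST(bar, 90)). The sketch below reproduces that dispatch as a standalone parser. PercentileSpec and PercentileParser are hypothetical; the single regular expression condenses the six single-argument patterns above, and stripping single quotes from the second argument is an assumption based on the parsePercentile() note that the value may be standardized into single-quoted form.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical parse result: percentile variant plus the requested percentile.
final class PercentileSpec {
  final String variant; // "", "EST", "TDIGEST", "MV", "ESTMV" or "TDIGESTMV"
  final int percentile;

  PercentileSpec(String variant, int percentile) {
    this.variant = variant;
    this.percentile = percentile;
  }
}

final class PercentileParser {
  // Single-argument names: optional EST/TDIGEST, digits, optional MV (e.g. "EST95MV").
  private static final Pattern EMBEDDED = Pattern.compile("^(EST|TDIGEST)?(\\d+)(MV)?$");
  private static final List<String> VARIANTS =
      Arrays.asList("", "EST", "TDIGEST", "MV", "ESTMV", "TDIGESTMV");

  // Assumption: the value may arrive single-quoted (e.g. '99'); strip quotes, then check 0..100.
  static int parsePercentile(String value) {
    String trimmed = value.trim();
    if (trimmed.length() >= 2 && trimmed.startsWith("'") && trimmed.endsWith("'")) {
      trimmed = trimmed.substring(1, trimmed.length() - 1);
    }
    int percentile = Integer.parseInt(trimmed);
    if (percentile < 0 || percentile > 100) {
      throw new IllegalArgumentException("Invalid percentile: " + value);
    }
    return percentile;
  }

  static PercentileSpec parse(String functionName, List<String> arguments) {
    String upperCaseName = functionName.toUpperCase(Locale.ROOT);
    if (!upperCaseName.startsWith("PERCENTILE")) {
      throw new IllegalArgumentException("Not a percentile function: " + functionName);
    }
    String remaining = upperCaseName.substring("PERCENTILE".length());
    if (arguments.size() == 1) {
      // Percentile embedded in the name, e.g. percentile99(foo), percentileEst95MV(bar)
      Matcher matcher = EMBEDDED.matcher(remaining);
      if (matcher.matches()) {
        String variant = (matcher.group(1) == null ? "" : matcher.group(1))
            + (matcher.group(3) == null ? "" : "MV");
        return new PercentileSpec(variant, parsePercentile(matcher.group(2)));
      }
    } else if (arguments.size() == 2 && VARIANTS.contains(remaining)) {
      // Percentile as second argument, e.g. percentile(foo, 99), percentileTDigest(bar, 95)
      return new PercentileSpec(remaining, parsePercentile(arguments.get(1)));
    }
    throw new IllegalArgumentException("Invalid percentile function: " + functionName);
  }

  public static void main(String[] args) {
    PercentileSpec spec = parse("percentileTDigest", Arrays.asList("bar", "90"));
    System.out.println(spec.variant + " " + spec.percentile); // prints "TDIGEST 90"
  }
}
```

Both spellings resolve to the same (variant, percentile) pair, so percentileTDigest90(bar) and percentileTDigest(bar, 90) would construct the same function.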
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
index 995440f..8b75740 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
@@ -22,16 +22,20 @@ import com.google.common.base.Preconditions;
 import com.google.common.math.DoubleMath;
 import java.io.Serializable;
 import java.util.Arrays;
-import java.util.LinkedHashSet;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Locale;
+import java.util.Map;
 import java.util.Set;
 import javax.annotation.Nullable;
 import org.apache.pinot.common.function.AggregationFunctionType;
 import org.apache.pinot.common.request.AggregationInfo;
 import org.apache.pinot.common.request.BrokerRequest;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.operator.blocks.TransformBlock;
 import org.apache.pinot.core.startree.v2.AggregationFunctionColumnPair;
 import org.apache.pinot.parsers.CompilerConstants;
 
@@ -44,79 +48,26 @@ public class AggregationFunctionUtils {
   }
 
   /**
-   * Extracts the aggregation column (could be column name or UDF expression) from the {@link AggregationInfo}.
-   *
-   *
-   */
-  /**
-   * Returns the arguments for {@link AggregationFunction} as List of Strings.
-   * For backward compatibility, it uses the new Thrift field `expressions` if found, or else,
-   * falls back to the previous aggregationParams based approach.
-   *
-   * @param aggregationInfo Aggregation Info
-   * @return List of aggregation function arguments
+   * Extracts the aggregation function arguments (could be column name, transform function or constant) from the
+   * {@link AggregationInfo} as an array of Strings.
+   * <p>NOTE: For backward-compatibility, uses the new Thrift field `expressions` if found, or falls back to the old
+   * aggregationParams based approach.
    */
-  public static List<String> getAggregationExpressions(AggregationInfo aggregationInfo) {
+  public static List<String> getArguments(AggregationInfo aggregationInfo) {
     List<String> expressions = aggregationInfo.getExpressions();
     if (expressions != null) {
       return expressions;
+    } else {
+      // NOTE: When the server is upgraded before the broker, the expressions won't be set. Falls back to the old
+      //       aggregationParams based approach.
+      String column = aggregationInfo.getAggregationParams().get(CompilerConstants.COLUMN_KEY_IN_AGGREGATION_INFO);
+      return Arrays.asList(column.split(CompilerConstants.AGGREGATION_FUNCTION_ARG_SEPARATOR));
     }
-
-    String params = aggregationInfo.getAggregationParams().get(CompilerConstants.COLUMN_KEY_IN_AGGREGATION_INFO);
-    return Arrays.asList(params.split(CompilerConstants.AGGREGATION_FUNCTION_ARG_SEPARATOR));
   }
 
   /**
-   * Creates an {@link AggregationFunctionColumnPair} from the {@link AggregationInfo}.
-   * Asserts that the function only expects one argument.
+   * Creates an array of {@link AggregationFunction}s based on the given {@link BrokerRequest}.
    */
-  public static AggregationFunctionColumnPair getFunctionColumnPair(AggregationInfo aggregationInfo) {
-    List<String> aggregationExpressions = getAggregationExpressions(aggregationInfo);
-    int numExpressions = aggregationExpressions.size();
-    AggregationFunctionType functionType =
-        AggregationFunctionType.getAggregationFunctionType(aggregationInfo.getAggregationType());
-    Preconditions
-        .checkState(numExpressions == 1, "Expected one argument for '" + functionType + "', got: " + numExpressions);
-    return new AggregationFunctionColumnPair(functionType, aggregationExpressions.get(0));
-  }
-
-  public static boolean isDistinct(AggregationFunctionContext[] functionContexts) {
-    return functionContexts.length == 1
-        && functionContexts[0].getAggregationFunction().getType() == AggregationFunctionType.DISTINCT;
-  }
-
-  /**
-   * Creates an {@link AggregationFunctionContext} from the {@link AggregationInfo}.
-   * NOTE: This method does not work for {@code DISTINCT} aggregation function.
-   * TODO: Remove this method and always pass in the broker request
-   */
-  public static AggregationFunctionContext getAggregationFunctionContext(AggregationInfo aggregationInfo) {
-    return getAggregationFunctionContext(aggregationInfo, null);
-  }
-
-  /**
-   * NOTE: Broker request cannot be {@code null} for {@code DISTINCT} aggregation function.
-   * TODO: Always pass in non-null broker request
-   */
-  public static AggregationFunctionContext getAggregationFunctionContext(AggregationInfo aggregationInfo,
-      @Nullable BrokerRequest brokerRequest) {
-    List<String> aggregationExpressions = getAggregationExpressions(aggregationInfo);
-    AggregationFunction aggregationFunction =
-        AggregationFunctionFactory.getAggregationFunction(aggregationInfo, brokerRequest);
-    return new AggregationFunctionContext(aggregationFunction, aggregationExpressions);
-  }
-
-  public static AggregationFunctionContext[] getAggregationFunctionContexts(BrokerRequest brokerRequest) {
-    List<AggregationInfo> aggregationInfos = brokerRequest.getAggregationsInfo();
-    int numAggregationFunctions = aggregationInfos.size();
-    AggregationFunctionContext[] aggregationFunctionContexts = new AggregationFunctionContext[numAggregationFunctions];
-    for (int i = 0; i < numAggregationFunctions; i++) {
-      AggregationInfo aggregationInfo = aggregationInfos.get(i);
-      aggregationFunctionContexts[i] = getAggregationFunctionContext(aggregationInfo, brokerRequest);
-    }
-    return aggregationFunctionContexts;
-  }
-
   public static AggregationFunction[] getAggregationFunctions(BrokerRequest brokerRequest) {
     List<AggregationInfo> aggregationInfos = brokerRequest.getAggregationsInfo();
     int numAggregationFunctions = aggregationInfos.size();
@@ -128,6 +79,29 @@ public class AggregationFunctionUtils {
     return aggregationFunctions;
   }
 
+  /**
+   * (For Star-Tree) Creates an {@link AggregationFunctionColumnPair} from the {@link AggregationFunction}. Returns
+   * {@code null} if the {@link AggregationFunction} cannot be represented as an {@link AggregationFunctionColumnPair}
+   * (e.g. it has multiple arguments, or its argument is not a column).
+   */
+  @Nullable
+  public static AggregationFunctionColumnPair getAggregationFunctionColumnPair(
+      AggregationFunction aggregationFunction) {
+    AggregationFunctionType aggregationFunctionType = aggregationFunction.getType();
+    if (aggregationFunctionType == AggregationFunctionType.COUNT) {
+      return AggregationFunctionColumnPair.COUNT_STAR;
+    }
+    //noinspection unchecked
+    List<TransformExpressionTree> inputExpressions = aggregationFunction.getInputExpressions();
+    if (inputExpressions.size() == 1) {
+      TransformExpressionTree inputExpression = inputExpressions.get(0);
+      if (inputExpression.isColumn()) {
+        return new AggregationFunctionColumnPair(aggregationFunctionType, inputExpression.getValue());
+      }
+    }
+    return null;
+  }
+
   public static boolean[] getAggregationFunctionsSelectStatus(List<AggregationInfo> aggregationInfos) {
     int numAggregationFunctions = aggregationInfos.size();
     boolean[] aggregationFunctionsStatus = new boolean[numAggregationFunctions];
@@ -164,13 +138,20 @@ public class AggregationFunctionUtils {
 
   /**
    * Utility function to parse percentile value from string.
-   * Asserts that percentile value is within 0 and 100.
+   * <p>Asserts that the percentile value is between 0 and 100 (inclusive).
+   * <p>NOTE: When percentileString comes from the second argument (e.g. percentile(foo, 99), percentileTDigest(bar, 95),
+   *          etc.), it might have been standardized into single-quoted form (e.g. '99'); the quotes are stripped here.
    *
    * @param percentileString Input String
    * @return Percentile value parsed from String.
    */
   public static int parsePercentile(String percentileString) {
-    int percentile = Integer.parseInt(percentileString);
+    int percentile;
+    if (percentileString.charAt(0) == '\'') {
+      percentile = Integer.parseInt(percentileString.substring(1, percentileString.length() - 1));
+    } else {
+      percentile = Integer.parseInt(percentileString);
+    }
     Preconditions.checkState(percentile >= 0 && percentile <= 100);
     return percentile;
   }
@@ -181,37 +162,63 @@ public class AggregationFunctionUtils {
    * @param arguments Arguments to concatenate
    * @return Concatenated String of arguments
    */
-  public static String concatArgs(List<String> arguments) {
-    return (arguments.size() > 1) ? String.join(CompilerConstants.AGGREGATION_FUNCTION_ARG_SEPARATOR, arguments)
-        : arguments.get(0);
+  public static String concatArgs(String[] arguments) {
+    return arguments.length > 1 ? String.join(CompilerConstants.AGGREGATION_FUNCTION_ARG_SEPARATOR, arguments)
+        : arguments[0];
   }
 
   /**
-   * Compiles and returns all transform expressions required for computing the aggregation, group-by
-   * and order-by
-   *
-   * @param brokerRequest Broker Request
-   * @param functionContexts Aggregation Function contexts
-   * @return Set of compiled expressions in the aggregation, group-by and order-by clauses
+   * Collects all transform expressions required for aggregation/group-by queries.
+   * <p>NOTE: We don't need to consider order-by columns here as the ordering is only allowed for aggregation functions
+   *          or group-by expressions.
    */
-  public static Set<TransformExpressionTree> collectExpressionsToTransform(BrokerRequest brokerRequest,
-      AggregationFunctionContext[] functionContexts) {
-
-    Set<TransformExpressionTree> expressionTrees = new LinkedHashSet<>();
-    for (AggregationFunctionContext functionContext : functionContexts) {
-      AggregationFunction function = functionContext.getAggregationFunction();
-      expressionTrees.addAll(function.getInputExpressions());
+  public static Set<TransformExpressionTree> collectExpressionsToTransform(AggregationFunction[] aggregationFunctions,
+      @Nullable TransformExpressionTree[] groupByExpressions) {
+    Set<TransformExpressionTree> expressions = new HashSet<>();
+    for (AggregationFunction aggregationFunction : aggregationFunctions) {
+      //noinspection unchecked
+      expressions.addAll(aggregationFunction.getInputExpressions());
+    }
+    if (groupByExpressions != null) {
+      expressions.addAll(Arrays.asList(groupByExpressions));
     }
+    return expressions;
+  }
 
-    // Extract group-by expressions
-    if (brokerRequest.isSetGroupBy()) {
-      for (String expression : brokerRequest.getGroupBy().getExpressions()) {
-        expressionTrees.add(TransformExpressionTree.compileToExpressionTree(expression));
-      }
+  /**
+   * Creates a map from each expression required by the {@link AggregationFunction} to the {@link BlockValSet} fetched
+   * from the {@link TransformBlock}.
+   */
+  public static Map<TransformExpressionTree, BlockValSet> getBlockValSetMap(AggregationFunction aggregationFunction,
+      TransformBlock transformBlock) {
+    //noinspection unchecked
+    List<TransformExpressionTree> expressions = aggregationFunction.getInputExpressions();
+    int numExpressions = expressions.size();
+    if (numExpressions == 0) {
+      return Collections.emptyMap();
     }
+    if (numExpressions == 1) {
+      TransformExpressionTree expression = expressions.get(0);
+      return Collections.singletonMap(expression, transformBlock.getBlockValueSet(expression));
+    }
+    Map<TransformExpressionTree, BlockValSet> blockValSetMap = new HashMap<>();
+    for (TransformExpressionTree expression : expressions) {
+      blockValSetMap.put(expression, transformBlock.getBlockValueSet(expression));
+    }
+    return blockValSetMap;
+  }
 
-    // TODO: Add order-by expressions when available in brokerRequest for aggregation queries.
-    // The current order-by implementation assumes that ordering will be on aggregation/group-by columns.
-    return expressionTrees;
+  /**
+   * (For Star-Tree) Creates a map from the expression required by the {@link AggregationFunctionColumnPair} to the
+   * {@link BlockValSet} fetched from the {@link TransformBlock}.
+   * <p>NOTE: We construct the map with the original column name as the key but fetch the BlockValSet with the
+   *          aggregation function column pair, so that the aggregation result column name is the same with or without
+   *          star-tree.
+   */
+  public static Map<TransformExpressionTree, BlockValSet> getBlockValSetMap(
+      AggregationFunctionColumnPair aggregationFunctionColumnPair, TransformBlock transformBlock) {
+    TransformExpressionTree expression = new TransformExpressionTree(TransformExpressionTree.ExpressionType.IDENTIFIER,
+        aggregationFunctionColumnPair.getColumn(), null);
+    BlockValSet blockValSet = transformBlock.getBlockValueSet(aggregationFunctionColumnPair.toColumnName());
+    return Collections.singletonMap(expression, blockValSet);
   }
 }
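Two of the utilities above are small enough to illustrate outside Pinot. The following is a hypothetical, self-contained sketch (class and method names are illustrative, not Pinot's API): `parsePercentile` accepts either `99` or the single-quoted `'99'` form and enforces the 0 to 100 range, and `buildValueMap` follows the same shape as `getBlockValSetMap`, returning an empty map, a singleton map, or a `HashMap` depending on how many input expressions there are. Plain `String` keys stand in for `TransformExpressionTree`.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical sketch of two AggregationFunctionUtils patterns; not Pinot code.
public class AggregationUtilsSketch {

  // Like the patched parsePercentile(): accepts "99" or "'99'" and enforces 0 <= p <= 100.
  // Slightly more defensive than the diff: quotes are stripped only when both ends are quoted.
  static int parsePercentile(String percentileString) {
    String digits = percentileString;
    if (digits.length() >= 2 && digits.charAt(0) == '\'' && digits.charAt(digits.length() - 1) == '\'') {
      digits = digits.substring(1, digits.length() - 1);
    }
    int percentile = Integer.parseInt(digits);
    if (percentile < 0 || percentile > 100) {
      throw new IllegalStateException("Percentile out of range [0, 100]: " + percentile);
    }
    return percentile;
  }

  // Same shape as getBlockValSetMap(): empty map for zero inputs, singleton map for one input,
  // HashMap otherwise. K stands in for TransformExpressionTree, V for BlockValSet.
  static <K, V> Map<K, V> buildValueMap(List<K> expressions, Function<K, V> fetcher) {
    int numExpressions = expressions.size();
    if (numExpressions == 0) {
      return Collections.emptyMap();
    }
    if (numExpressions == 1) {
      K expression = expressions.get(0);
      return Collections.singletonMap(expression, fetcher.apply(expression));
    }
    Map<K, V> valueMap = new HashMap<>();
    for (K expression : expressions) {
      valueMap.put(expression, fetcher.apply(expression));
    }
    return valueMap;
  }

  public static void main(String[] args) {
    System.out.println(parsePercentile("99"));   // 99
    System.out.println(parsePercentile("'95'")); // 95
    Map<String, Integer> lengths = buildValueMap(Arrays.asList("foo", "barbaz"), String::length);
    System.out.println(lengths.get("barbaz"));   // 6
  }
}
```

Using expression objects as map keys only works if they implement `equals()`/`hashCode()` consistently, which is what the commit message relies on for `TransformExpressionTree`.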
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java
index 33d21a6..52f2d6c 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java
@@ -18,8 +18,6 @@
  */
 package org.apache.pinot.core.query.aggregation.function;
 
-import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -31,23 +29,14 @@ import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.function.customobject.AvgPair;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
-import org.apache.pinot.pql.parsers.pql2.ast.IdentifierAstNode;
 import org.apache.pinot.spi.data.FieldSpec.DataType;
 
 
-public class AvgAggregationFunction implements AggregationFunction<AvgPair, Double> {
+public class AvgAggregationFunction extends BaseSingleInputAggregationFunction<AvgPair, Double> {
   private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY;
 
-  protected final String _column;
-  protected final List<TransformExpressionTree> _inputExpressions;
-
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
   public AvgAggregationFunction(String column) {
-    _column = column;
-    _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+    super(column);
   }
 
   @Override
@@ -56,21 +45,6 @@ public class AvgAggregationFunction implements AggregationFunction<AvgPair, Doub
   }
 
   @Override
-  public String getColumnName() {
-    return getType().getName() + "_" + _column;
-  }
-
-  @Override
-  public String getResultColumnName() {
-    return getType().getName().toLowerCase() + "(" + _column + ")";
-  }
-
-  @Override
-  public List<TransformExpressionTree> getInputExpressions() {
-    return _inputExpressions;
-  }
-
-  @Override
   public void accept(AggregationFunctionVisitorBase visitor) {
     visitor.visit(this);
   }
@@ -87,8 +61,8 @@ public class AvgAggregationFunction implements AggregationFunction<AvgPair, Doub
 
   @Override
   public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
 
     if (blockValSet.getValueType() != DataType.BYTES) {
       double[] doubleValues = blockValSet.getDoubleValuesSV();
@@ -96,7 +70,7 @@ public class AvgAggregationFunction implements AggregationFunction<AvgPair, Doub
       for (int i = 0; i < length; i++) {
         sum += doubleValues[i];
       }
-      setAggregationResult(aggregationResultHolder, sum, (long) length);
+      setAggregationResult(aggregationResultHolder, sum, length);
     } else {
       // Serialized AvgPair
       byte[][] bytesValues = blockValSet.getBytesValuesSV();
@@ -122,8 +96,8 @@ public class AvgAggregationFunction implements AggregationFunction<AvgPair, Doub
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
 
     if (blockValSet.getValueType() != DataType.BYTES) {
       double[] doubleValues = blockValSet.getDoubleValuesSV();
@@ -142,8 +116,8 @@ public class AvgAggregationFunction implements AggregationFunction<AvgPair, Doub
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
 
     if (blockValSet.getValueType() != DataType.BYTES) {
       double[] doubleValues = blockValSet.getDoubleValuesSV();
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgMVAggregationFunction.java
index f45b41f..abe319b 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgMVAggregationFunction.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.query.aggregation.function;
 
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -27,10 +28,6 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 
 public class AvgMVAggregationFunction extends AvgAggregationFunction {
 
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
   public AvgMVAggregationFunction(String column) {
     super(column);
   }
@@ -46,8 +43,9 @@ public class AvgMVAggregationFunction extends AvgAggregationFunction {
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     double sum = 0.0;
     long count = 0L;
     for (int i = 0; i < length; i++) {
@@ -62,8 +60,8 @@ public class AvgMVAggregationFunction extends AvgAggregationFunction {
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     for (int i = 0; i < length; i++) {
       aggregateOnGroupKey(groupKeyArray[i], groupByResultHolder, valuesArray[i]);
     }
@@ -71,8 +69,8 @@ public class AvgMVAggregationFunction extends AvgAggregationFunction {
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     for (int i = 0; i < length; i++) {
       double[] values = valuesArray[i];
       for (int groupKey : groupKeysArray[i]) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseSingleInputAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseSingleInputAggregationFunction.java
new file mode 100644
index 0000000..43641b2
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseSingleInputAggregationFunction.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import java.util.Collections;
+import java.util.List;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
+
+
+/**
+ * Base implementation of {@link AggregationFunction} with single input expression.
+ */
+public abstract class BaseSingleInputAggregationFunction<I, F extends Comparable> implements AggregationFunction<I, F> {
+  protected final String _column;
+  protected final TransformExpressionTree _expression;
+
+  /**
+   * Constructor for the class.
+   *
+   * @param column Column to aggregate on (could be column name or transform function).
+   */
+  public BaseSingleInputAggregationFunction(String column) {
+    _column = column;
+    _expression = TransformExpressionTree.compileToExpressionTree(column);
+  }
+
+  @Override
+  public String getColumnName() {
+    return getType().getName() + "_" + _column;
+  }
+
+  @Override
+  public String getResultColumnName() {
+    return getType().getName().toLowerCase() + "(" + _column + ")";
+  }
+
+  @Override
+  public List<TransformExpressionTree> getInputExpressions() {
+    return Collections.singletonList(_expression);
+  }
+}
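The new base class is a template: subclasses provide only their type, and the column name, result column name, and single input expression are derived once in the base. A hypothetical, self-contained sketch of that pattern follows. The type names `avg` and `distinctCount` are assumed values for illustration, and a plain `String` stands in for the `TransformExpressionTree` that the real class compiles in its constructor.

```java
import java.util.Collections;
import java.util.List;
import java.util.Locale;

// Hypothetical sketch of the single-input base-class pattern; names are illustrative, not Pinot's.
public class SingleInputBaseSketch {

  abstract static class BaseSingleInput {
    protected final String _column;
    // Stand-in for the TransformExpressionTree the real base class compiles once, in its constructor.
    protected final String _expression;

    BaseSingleInput(String column) {
      _column = column;
      _expression = column;
    }

    // Each subclass supplies only its type name (assumed values below).
    abstract String typeName();

    String getColumnName() {
      return typeName() + "_" + _column;
    }

    String getResultColumnName() {
      return typeName().toLowerCase(Locale.ROOT) + "(" + _column + ")";
    }

    List<String> getInputExpressions() {
      return Collections.singletonList(_expression);
    }
  }

  static final class Avg extends BaseSingleInput {
    Avg(String column) {
      super(column);
    }

    @Override
    String typeName() {
      return "avg";
    }
  }

  static final class DistinctCount extends BaseSingleInput {
    DistinctCount(String column) {
      super(column);
    }

    @Override
    String typeName() {
      return "distinctCount";
    }
  }

  public static void main(String[] args) {
    BaseSingleInput distinctCount = new DistinctCount("foo");
    System.out.println(distinctCount.getColumnName());        // distinctCount_foo
    System.out.println(distinctCount.getResultColumnName());  // distinctcount(foo)
    System.out.println(new Avg("bar").getInputExpressions()); // [bar]
  }
}
```

The sketch passes `Locale.ROOT` to `toLowerCase` so that results do not depend on the JVM's default locale; the diff itself calls the no-argument overload.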
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java
index 20ce56b..a437ce0 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java
@@ -29,21 +29,17 @@ import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.DoubleAggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.groupby.DoubleGroupByResultHolder;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import org.apache.pinot.core.startree.v2.AggregationFunctionColumnPair;
 
 
 public class CountAggregationFunction implements AggregationFunction<Long, Long> {
-  private static final String COLUMN_NAME = AggregationFunctionType.COUNT.getName() + "_star";
+  private static final String COLUMN_NAME = "count_star";
+  private static final String RESULT_COLUMN_NAME = "count(*)";
   private static final double DEFAULT_INITIAL_VALUE = 0.0;
-
-  protected final String _column;
-
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
-  public CountAggregationFunction(String column) {
-    _column = column;
-  }
+  // Special expression used by star-tree to pass in BlockValSet
+  private static final TransformExpressionTree STAR_TREE_COUNT_STAR_EXPRESSION =
+      new TransformExpressionTree(TransformExpressionTree.ExpressionType.IDENTIFIER, AggregationFunctionColumnPair.STAR,
+          null);
 
   @Override
   public AggregationFunctionType getType() {
@@ -57,7 +53,7 @@ public class CountAggregationFunction implements AggregationFunction<Long, Long>
 
   @Override
   public String getResultColumnName() {
-    return AggregationFunctionType.COUNT.getName().toLowerCase() + "(*)";
+    return RESULT_COLUMN_NAME;
   }
 
   @Override
@@ -82,12 +78,12 @@ public class CountAggregationFunction implements AggregationFunction<Long, Long>
 
   @Override
   public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
     if (blockValSetMap.size() == 0) {
       aggregationResultHolder.setValue(aggregationResultHolder.getDoubleResult() + length);
     } else {
       // Star-tree pre-aggregated values
-      long[] valueArray = blockValSetMap.get(_column).getLongValuesSV();
+      long[] valueArray = blockValSetMap.get(STAR_TREE_COUNT_STAR_EXPRESSION).getLongValuesSV();
       long count = 0;
       for (int i = 0; i < length; i++) {
         count += valueArray[i];
@@ -98,7 +94,7 @@ public class CountAggregationFunction implements AggregationFunction<Long, Long>
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
     if (blockValSetMap.size() == 0) {
       for (int i = 0; i < length; i++) {
         int groupKey = groupKeyArray[i];
@@ -106,7 +102,7 @@ public class CountAggregationFunction implements AggregationFunction<Long, Long>
       }
     } else {
       // Star-tree pre-aggregated values
-      long[] valueArray = blockValSetMap.get(_column).getLongValuesSV();
+      long[] valueArray = blockValSetMap.get(STAR_TREE_COUNT_STAR_EXPRESSION).getLongValuesSV();
       for (int i = 0; i < length; i++) {
         int groupKey = groupKeyArray[i];
         groupByResultHolder.setValueForKey(groupKey, groupByResultHolder.getDoubleResult(groupKey) + valueArray[i]);
@@ -116,7 +112,7 @@ public class CountAggregationFunction implements AggregationFunction<Long, Long>
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
     if (blockValSetMap.size() == 0) {
       for (int i = 0; i < length; i++) {
         for (int groupKey : groupKeysArray[i]) {
@@ -125,7 +121,7 @@ public class CountAggregationFunction implements AggregationFunction<Long, Long>
       }
     } else {
       // Star-tree pre-aggregated values
-      long[] valueArray = blockValSetMap.get(_column).getLongValuesSV();
+      long[] valueArray = blockValSetMap.get(STAR_TREE_COUNT_STAR_EXPRESSION).getLongValuesSV();
       for (int i = 0; i < length; i++) {
         long value = valueArray[i];
         for (int groupKey : groupKeysArray[i]) {
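`CountAggregationFunction` now distinguishes two cases by the size of the map alone: with no inputs, every row counts as one; with star-tree, a single entry keyed by the special `*` expression carries pre-aggregated counts that must be summed. A hypothetical, self-contained sketch of that branch follows, with a `String` key standing in for `STAR_TREE_COUNT_STAR_EXPRESSION` and a `long` accumulator in place of the double-based result holder used in the diff.

```java
import java.util.Collections;
import java.util.Map;

// Hypothetical sketch of the COUNT(*) aggregate branch; not Pinot code.
public class CountStarSketch {
  // Stand-in key for STAR_TREE_COUNT_STAR_EXPRESSION (column "*").
  static final String STAR = "*";

  // Empty map: raw rows, so each of the `length` rows adds 1.
  // Non-empty map: star-tree rows, so add each row's pre-aggregated count.
  static long aggregate(int length, long currentCount, Map<String, long[]> valueMap) {
    if (valueMap.isEmpty()) {
      return currentCount + length;
    }
    long[] preAggregated = valueMap.get(STAR);
    long count = currentCount;
    for (int i = 0; i < length; i++) {
      count += preAggregated[i];
    }
    return count;
  }

  public static void main(String[] args) {
    System.out.println(aggregate(4, 0L, Collections.emptyMap())); // 4
    System.out.println(aggregate(3, 10L, Collections.singletonMap(STAR, new long[]{5, 2, 7}))); // 24
  }
}
```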
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountMVAggregationFunction.java
index 29b540e..96afabe 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountMVAggregationFunction.java
@@ -29,16 +29,17 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 
 
 public class CountMVAggregationFunction extends CountAggregationFunction {
-
-  private final List<TransformExpressionTree> _inputExpressions;
+  private final String _column;
+  private final TransformExpressionTree _expression;
 
   /**
    * Constructor for the class.
-   * @param column Column name to aggregate on.
+   *
+   * @param column Column to aggregate on (could be column name or transform function).
    */
   public CountMVAggregationFunction(String column) {
-    super(column);
-    _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(column));
+    _column = column;
+    _expression = TransformExpressionTree.compileToExpressionTree(column);
   }
 
   @Override
@@ -48,17 +49,17 @@ public class CountMVAggregationFunction extends CountAggregationFunction {
 
   @Override
   public String getColumnName() {
-    return getType().getName() + "_" + _column;
+    return AggregationFunctionType.COUNTMV.getName() + "_" + _column;
   }
 
   @Override
   public String getResultColumnName() {
-    return getType().getName().toLowerCase() + "(" + _column + ")";
+    return AggregationFunctionType.COUNTMV.getName().toLowerCase() + "(" + _column + ")";
   }
 
   @Override
   public List<TransformExpressionTree> getInputExpressions() {
-    return _inputExpressions;
+    return Collections.singletonList(_expression);
   }
 
   @Override
@@ -68,8 +69,8 @@ public class CountMVAggregationFunction extends CountAggregationFunction {
 
   @Override
   public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    int[] valueArray = blockValSetMap.get(_column).getNumMVEntries();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    int[] valueArray = blockValSetMap.get(_expression).getNumMVEntries();
     long count = 0L;
     for (int i = 0; i < length; i++) {
       count += valueArray[i];
@@ -79,8 +80,8 @@ public class CountMVAggregationFunction extends CountAggregationFunction {
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    int[] valueArray = blockValSetMap.get(_column).getNumMVEntries();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    int[] valueArray = blockValSetMap.get(_expression).getNumMVEntries();
     for (int i = 0; i < length; i++) {
       int groupKey = groupKeyArray[i];
       groupByResultHolder.setValueForKey(groupKey, groupByResultHolder.getDoubleResult(groupKey) + valueArray[i]);
@@ -89,8 +90,8 @@ public class CountMVAggregationFunction extends CountAggregationFunction {
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    int[] valueArray = blockValSetMap.get(_column).getNumMVEntries();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    int[] valueArray = blockValSetMap.get(_expression).getNumMVEntries();
     for (int i = 0; i < length; i++) {
       int value = valueArray[i];
       for (int groupKey : groupKeysArray[i]) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctAggregationFunction.java
index b09f028..8aceb29 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctAggregationFunction.java
@@ -37,7 +37,6 @@ import org.apache.pinot.core.query.aggregation.DistinctTable;
 import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 import org.apache.pinot.core.util.GroupByUtils;
-import org.apache.pinot.pql.parsers.pql2.ast.IdentifierAstNode;
 
 
 /**
@@ -79,12 +78,13 @@ public class DistinctAggregationFunction implements AggregationFunction<Distinct
 
   @Override
   public String getColumnName() {
-    return getType().getName() + "_" + AggregationFunctionUtils.concatArgs(Arrays.asList(_columns));
+    return AggregationFunctionType.DISTINCT.getName() + "_" + AggregationFunctionUtils.concatArgs(_columns);
   }
 
   @Override
   public String getResultColumnName() {
-    return getType().getName().toLowerCase() + "(" + AggregationFunctionUtils.concatArgs(Arrays.asList(_columns)) + ")";
+    return AggregationFunctionType.DISTINCT.getName().toLowerCase() + "(" + AggregationFunctionUtils
+        .concatArgs(_columns) + ")";
   }
 
   @Override
@@ -104,23 +104,23 @@ public class DistinctAggregationFunction implements AggregationFunction<Distinct
 
   @Override
   public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    int numColumns = _columns.length;
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
     int numBlockValSets = blockValSetMap.size();
-    Preconditions.checkState(numBlockValSets == numColumns, "Size mismatch: numBlockValSets = %s, numColumns = %s",
-        numBlockValSets, numColumns);
-
-    DistinctTable distinctTable = aggregationResultHolder.getResult();
-    BlockValSet[] blockValSets = new BlockValSet[numColumns];
-
-    for (int i = 0; i < numColumns; i++) {
-      blockValSets[i] = blockValSetMap.get(_columns[i]);
+    int numExpressions = _inputExpressions.size();
+    Preconditions
+        .checkState(numBlockValSets == numExpressions, "Size mismatch: numBlockValSets = %s, numExpressions = %s",
+            numBlockValSets, numExpressions);
+
+    BlockValSet[] blockValSets = new BlockValSet[numExpressions];
+    for (int i = 0; i < numExpressions; i++) {
+      blockValSets[i] = blockValSetMap.get(_inputExpressions.get(i));
     }
 
+    DistinctTable distinctTable = aggregationResultHolder.getResult();
     if (distinctTable == null) {
-      ColumnDataType[] columnDataTypes = new ColumnDataType[numColumns];
-      for (int i = 0; i < numColumns; i++) {
-        columnDataTypes[i] = ColumnDataType.fromDataTypeSV(blockValSetMap.get(_columns[i]).getValueType());
+      ColumnDataType[] columnDataTypes = new ColumnDataType[numExpressions];
+      for (int i = 0; i < numExpressions; i++) {
+        columnDataTypes[i] = ColumnDataType.fromDataTypeSV(blockValSetMap.get(_inputExpressions.get(i)).getValueType());
       }
       DataSchema dataSchema = new DataSchema(_columns, columnDataTypes);
       distinctTable = new DistinctTable(dataSchema, _orderBy, _capacity);
@@ -147,8 +147,7 @@ public class DistinctAggregationFunction implements AggregationFunction<Distinct
     if (distinctTable != null) {
       return distinctTable;
     } else {
-      int numColumns = _columns.length;
-      ColumnDataType[] columnDataTypes = new ColumnDataType[numColumns];
+      ColumnDataType[] columnDataTypes = new ColumnDataType[_columns.length];
       // NOTE: Use STRING for unknown type
       Arrays.fill(columnDataTypes, ColumnDataType.STRING);
       return new DistinctTable(new DataSchema(_columns, columnDataTypes), _orderBy, _capacity);
@@ -188,13 +187,13 @@ public class DistinctAggregationFunction implements AggregationFunction<Distinct
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
     throw new UnsupportedOperationException("Operation not supported for DISTINCT aggregation function");
   }
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
     throw new UnsupportedOperationException("Operation not supported for DISTINCT aggregation function");
   }
 
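The patched `DistinctAggregationFunction.aggregate()` checks that the map holds exactly one entry per input expression and then resolves the values in the declared expression order, so that each column lines up with `_columns` in the `DataSchema`. A hypothetical, generic sketch of that step follows (not Pinot code; `K` stands in for `TransformExpressionTree` and the values for `BlockValSet`).

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of resolving map entries in input-expression order; not Pinot code.
public class DistinctInputOrderSketch {

  static <K, V> Object[] resolveInOrder(List<K> inputExpressions, Map<K, V> valueMap) {
    int numExpressions = inputExpressions.size();
    // Same check as the diff's Preconditions.checkState(...).
    if (valueMap.size() != numExpressions) {
      throw new IllegalStateException(String.format(
          "Size mismatch: numBlockValSets = %s, numExpressions = %s", valueMap.size(), numExpressions));
    }
    // Map iteration order is irrelevant: values are looked up by the declared expression order.
    Object[] values = new Object[numExpressions];
    for (int i = 0; i < numExpressions; i++) {
      values[i] = valueMap.get(inputExpressions.get(i));
    }
    return values;
  }

  public static void main(String[] args) {
    Map<String, String> valueMap = new HashMap<>();
    valueMap.put("a", "A");
    valueMap.put("b", "B");
    System.out.println(Arrays.toString(resolveInOrder(Arrays.asList("b", "a"), valueMap))); // [B, A]
  }
}
```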
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java
index ebc3aaf..4fafb1d 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java
@@ -19,8 +19,6 @@
 package org.apache.pinot.core.query.aggregation.function;
 
 import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
-import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -33,18 +31,10 @@ import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder
 import org.apache.pinot.spi.data.FieldSpec;
 
 
-public class DistinctCountAggregationFunction implements AggregationFunction<IntOpenHashSet, Integer> {
+public class DistinctCountAggregationFunction extends BaseSingleInputAggregationFunction<IntOpenHashSet, Integer> {
 
-  protected final String _column;
-  private final List<TransformExpressionTree> _inputExpressions;
-
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
   public DistinctCountAggregationFunction(String column) {
-    _column = column;
-    _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+    super(column);
   }
 
   @Override
@@ -53,21 +43,6 @@ public class DistinctCountAggregationFunction implements AggregationFunction<Int
   }
 
   @Override
-  public String getColumnName() {
-    return getType().getName() + "_" + _column;
-  }
-
-  @Override
-  public String getResultColumnName() {
-    return getType().getName().toLowerCase() + "(" + _column + ")";
-  }
-
-  @Override
-  public List<TransformExpressionTree> getInputExpressions() {
-    return _inputExpressions;
-  }
-
-  @Override
   public void accept(AggregationFunctionVisitorBase visitor) {
     visitor.visit(this);
   }
@@ -83,8 +58,9 @@ public class DistinctCountAggregationFunction implements AggregationFunction<Int
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     IntOpenHashSet valueSet = getValueSet(aggregationResultHolder);
 
     FieldSpec.DataType valueType = blockValSet.getValueType();
@@ -126,8 +102,8 @@ public class DistinctCountAggregationFunction implements AggregationFunction<Int
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     FieldSpec.DataType valueType = blockValSet.getValueType();
 
     switch (valueType) {
@@ -168,8 +144,8 @@ public class DistinctCountAggregationFunction implements AggregationFunction<Int
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
 
     FieldSpec.DataType valueType = blockValSet.getValueType();
     switch (valueType) {
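In the DistinctCountAggregationFunction hunks above, the per-class _column and _inputExpressions fields are removed, along with the getColumnName(), getResultColumnName() and getInputExpressions() overrides. The class now calls super(column) and looks up blockValSetMap by _expression. BaseSingleInputAggregationFunction itself is not part of this excerpt, so the block below is only a minimal sketch under stated assumptions:
- Expr is a hypothetical stand-in for TransformExpressionTree.
- typeName() stands in for getType().getName() and is assumed to return "distinctCount".
- int[] stands in for BlockValSet.

The two naming methods reuse the string logic of the methods the diff deletes.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class SingleInputSketch {

  // Hypothetical stand-in for TransformExpressionTree: value-based equals()/hashCode() so it can key a HashMap.
  static final class Expr {
    final String identifier;

    Expr(String identifier) {
      this.identifier = identifier;
    }

    @Override
    public boolean equals(Object o) {
      return o instanceof Expr && ((Expr) o).identifier.equals(identifier);
    }

    @Override
    public int hashCode() {
      return identifier.hashCode();
    }
  }

  // Hypothetical base class holding the boilerplate that each function used to repeat.
  abstract static class BaseSingleInput {
    protected final String _column;
    protected final Expr _expression;

    BaseSingleInput(String column) {
      _column = column;
      _expression = new Expr(column); // built once, then reused as the map key
    }

    abstract String typeName(); // assumed equivalent of getType().getName()

    String getColumnName() {
      return typeName() + "_" + _column;
    }

    String getResultColumnName() {
      return typeName().toLowerCase() + "(" + _column + ")";
    }

    List<Expr> getInputExpressions() {
      return Collections.singletonList(_expression);
    }
  }

  // Simplified DISTINCTCOUNT over int values; a HashSet plays the role of IntOpenHashSet.
  static final class DistinctCount extends BaseSingleInput {
    final Set<Integer> valueSet = new HashSet<>();

    DistinctCount(String column) {
      super(column);
    }

    @Override
    String typeName() {
      return "distinctCount";
    }

    void aggregate(int length, Map<Expr, int[]> blockValSetMap) {
      int[] values = blockValSetMap.get(_expression); // lookup by expression, not by column string
      for (int i = 0; i < length; i++) {
        valueSet.add(values[i]);
      }
    }
  }

  public static void main(String[] args) {
    DistinctCount fn = new DistinctCount("userId");
    Map<Expr, int[]> blockValSetMap = new HashMap<>();
    // The caller keys the map with exactly the expressions the function reports as its inputs.
    for (Expr e : fn.getInputExpressions()) {
      blockValSetMap.put(e, new int[]{3, 1, 3, 2});
    }
    fn.aggregate(4, blockValSetMap);
    System.out.println(fn.getResultColumnName() + " = " + fn.valueSet.size()); // distinctcount(userId) = 3
  }
}
```

This pattern depends on the map keys matching getInputExpressions(), which the commit message states as a requirement. It also depends on the key type having value-based equals() and hashCode().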
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLAggregationFunction.java
index f8371ea..88beff9 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLAggregationFunction.java
@@ -20,8 +20,6 @@ package org.apache.pinot.core.query.aggregation.function;
 
 import com.clearspring.analytics.stream.cardinality.HyperLogLog;
 import com.google.common.base.Preconditions;
-import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -35,19 +33,11 @@ import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder
 import org.apache.pinot.spi.data.FieldSpec.DataType;
 
 
-public class DistinctCountHLLAggregationFunction implements AggregationFunction<HyperLogLog, Long> {
-  protected final String _column;
-
+public class DistinctCountHLLAggregationFunction extends BaseSingleInputAggregationFunction<HyperLogLog, Long> {
   public static final int DEFAULT_LOG2M = 8;
-  private final List<TransformExpressionTree> _inputExpressions;
 
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
   public DistinctCountHLLAggregationFunction(String column) {
-    _column = column;
-    _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+    super(column);
   }
 
   @Override
@@ -56,21 +46,6 @@ public class DistinctCountHLLAggregationFunction implements AggregationFunction<
   }
 
   @Override
-  public String getColumnName() {
-    return getType().getName() + "_" + _column;
-  }
-
-  @Override
-  public String getResultColumnName() {
-    return getType().getName().toLowerCase() + "(" + _column + ")";
-  }
-
-  @Override
-  public List<TransformExpressionTree> getInputExpressions() {
-    return _inputExpressions;
-  }
-
-  @Override
   public void accept(AggregationFunctionVisitorBase visitor) {
     visitor.visit(this);
   }
@@ -86,8 +61,9 @@ public class DistinctCountHLLAggregationFunction implements AggregationFunction<
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     DataType valueType = blockValSet.getValueType();
 
     if (valueType != DataType.BYTES) {
@@ -151,8 +127,8 @@ public class DistinctCountHLLAggregationFunction implements AggregationFunction<
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     DataType valueType = blockValSet.getValueType();
 
     switch (valueType) {
@@ -211,8 +187,8 @@ public class DistinctCountHLLAggregationFunction implements AggregationFunction<
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     DataType valueType = blockValSet.getValueType();
 
     switch (valueType) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLMVAggregationFunction.java
index eee8db4..1ff547c 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLMVAggregationFunction.java
@@ -21,6 +21,7 @@ package org.apache.pinot.core.query.aggregation.function;
 import com.clearspring.analytics.stream.cardinality.HyperLogLog;
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -29,10 +30,6 @@ import org.apache.pinot.spi.data.FieldSpec.DataType;
 
 public class DistinctCountHLLMVAggregationFunction extends DistinctCountHLLAggregationFunction {
 
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
   public DistinctCountHLLMVAggregationFunction(String column) {
     super(column);
   }
@@ -48,10 +45,11 @@ public class DistinctCountHLLMVAggregationFunction extends DistinctCountHLLAggre
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
     HyperLogLog hyperLogLog = getDefaultHyperLogLog(aggregationResultHolder);
 
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     DataType valueType = blockValSet.getValueType();
     switch (valueType) {
       case INT:
@@ -102,8 +100,8 @@ public class DistinctCountHLLMVAggregationFunction extends DistinctCountHLLAggre
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     DataType valueType = blockValSet.getValueType();
 
     switch (valueType) {
@@ -160,8 +158,8 @@ public class DistinctCountHLLMVAggregationFunction extends DistinctCountHLLAggre
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     DataType valueType = blockValSet.getValueType();
 
     switch (valueType) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java
index 5b69532..910c8aa 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java
@@ -21,6 +21,7 @@ package org.apache.pinot.core.query.aggregation.function;
 import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -29,10 +30,6 @@ import org.apache.pinot.spi.data.FieldSpec;
 
 public class DistinctCountMVAggregationFunction extends DistinctCountAggregationFunction {
 
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
   public DistinctCountMVAggregationFunction(String column) {
     super(column);
   }
@@ -48,10 +45,11 @@ public class DistinctCountMVAggregationFunction extends DistinctCountAggregation
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
     IntOpenHashSet valueSet = getValueSet(aggregationResultHolder);
 
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     FieldSpec.DataType valueType = blockValSet.getValueType();
     switch (valueType) {
       case INT:
@@ -100,8 +98,8 @@ public class DistinctCountMVAggregationFunction extends DistinctCountAggregation
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     FieldSpec.DataType valueType = blockValSet.getValueType();
 
     switch (valueType) {
@@ -157,8 +155,8 @@ public class DistinctCountMVAggregationFunction extends DistinctCountAggregation
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     FieldSpec.DataType valueType = blockValSet.getValueType();
 
     switch (valueType) {
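The aggregateGroupByMV hunk in DistinctCountMVAggregationFunction shows only the new signature, not the loop body. The sketch below is hypothetical and shows the usual shape of that loop: a row holds several values and belongs to several groups, so every value goes into the distinct set of every group the row maps to. Assumptions:
- SetResultHolder is an invented stand-in for GroupByResultHolder.
- The int[][] values array stands in for the multi-value data a BlockValSet would supply.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class GroupByMVSketch {

  // Hypothetical stand-in for GroupByResultHolder: one distinct-value set per group key.
  static final class SetResultHolder {
    private final Map<Integer, Set<Integer>> sets = new HashMap<>();

    Set<Integer> getValueSet(int groupKey) {
      return sets.computeIfAbsent(groupKey, k -> new HashSet<>());
    }
  }

  // values[i] holds row i's values; groupKeysArray[i] holds every group row i belongs to.
  static void aggregateGroupByMV(int length, int[][] groupKeysArray, int[][] values, SetResultHolder holder) {
    for (int i = 0; i < length; i++) {
      for (int groupKey : groupKeysArray[i]) {
        Set<Integer> valueSet = holder.getValueSet(groupKey);
        for (int value : values[i]) {
          valueSet.add(value);
        }
      }
    }
  }

  public static void main(String[] args) {
    SetResultHolder holder = new SetResultHolder();
    int[][] groupKeys = {{0, 1}, {1}};
    int[][] values = {{10, 20}, {20, 30}};
    aggregateGroupByMV(2, groupKeys, values, holder);
    System.out.println(holder.getValueSet(0).size()); // 2: {10, 20}
    System.out.println(holder.getValueSet(1).size()); // 3: {10, 20, 30}
  }
}
```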
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLAggregationFunction.java
index fe946ca..40050f6 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLAggregationFunction.java
@@ -19,8 +19,6 @@
 package org.apache.pinot.core.query.aggregation.function;
 
 import com.clearspring.analytics.stream.cardinality.HyperLogLog;
-import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -31,25 +29,17 @@ import org.apache.pinot.core.query.aggregation.function.customobject.SerializedH
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 
 
-public class DistinctCountRawHLLAggregationFunction implements AggregationFunction<HyperLogLog, SerializedHLL> {
+public class DistinctCountRawHLLAggregationFunction extends BaseSingleInputAggregationFunction<HyperLogLog, SerializedHLL> {
   private final DistinctCountHLLAggregationFunction _distinctCountHLLAggregationFunction;
 
-  protected final String _column;
-  private final List<TransformExpressionTree> _inputExpressions;
-
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
   public DistinctCountRawHLLAggregationFunction(String column) {
-    this(new DistinctCountHLLAggregationFunction(column), column);
+    this(column, new DistinctCountHLLAggregationFunction(column));
   }
 
-  DistinctCountRawHLLAggregationFunction(DistinctCountHLLAggregationFunction distinctCountHLLAggregationFunction,
-      String column) {
+  DistinctCountRawHLLAggregationFunction(String column,
+      DistinctCountHLLAggregationFunction distinctCountHLLAggregationFunction) {
+    super(column);
     _distinctCountHLLAggregationFunction = distinctCountHLLAggregationFunction;
-    _column = column;
-    _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
   }
 
   @Override
@@ -58,21 +48,6 @@ public class DistinctCountRawHLLAggregationFunction implements AggregationFuncti
   }
 
   @Override
-  public String getColumnName() {
-    return getType().getName() + "_" + _column;
-  }
-
-  @Override
-  public String getResultColumnName() {
-    return getType().getName().toLowerCase() + "(" + _column + ")";
-  }
-
-  @Override
-  public List<TransformExpressionTree> getInputExpressions() {
-    return _inputExpressions;
-  }
-
-  @Override
   public void accept(AggregationFunctionVisitorBase visitor) {
     _distinctCountHLLAggregationFunction.accept(visitor);
   }
@@ -89,19 +64,19 @@ public class DistinctCountRawHLLAggregationFunction implements AggregationFuncti
 
   @Override
   public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
     _distinctCountHLLAggregationFunction.aggregate(length, aggregationResultHolder, blockValSetMap);
   }
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
     _distinctCountHLLAggregationFunction.aggregateGroupBySV(length, groupKeyArray, groupByResultHolder, blockValSetMap);
   }
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
     _distinctCountHLLAggregationFunction
         .aggregateGroupByMV(length, groupKeysArray, groupByResultHolder, blockValSetMap);
   }
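DistinctCountRawHLLAggregationFunction now takes the column first and the delegate second, calls super(column), and forwards every aggregate call to the wrapped DistinctCountHLLAggregationFunction. The MV raw variant only passes a different delegate. The sketch below mirrors that composition with simplified, hypothetical classes:
- A HashSet replaces HyperLogLog.
- Object[] rows stand in for BlockValSet.
- The raw variant is assumed to return the intermediate state rather than a count; the real SerializedHLL conversion is not shown in this excerpt.

```java
import java.util.HashSet;
import java.util.Set;

public class RawDelegationSketch {

  // Hypothetical stand-in for DistinctCountHLLAggregationFunction; each row is an Integer.
  static class DistinctCountFn {
    final String column;

    DistinctCountFn(String column) {
      this.column = column;
    }

    void aggregate(Object[] rows, Set<Integer> state) {
      for (Object row : rows) {
        state.add((Integer) row);
      }
    }
  }

  // Hypothetical MV variant; each row is an int[] of values.
  static class DistinctCountMVFn extends DistinctCountFn {
    DistinctCountMVFn(String column) {
      super(column);
    }

    @Override
    void aggregate(Object[] rows, Set<Integer> state) {
      for (Object row : rows) {
        for (int v : (int[]) row) {
          state.add(v);
        }
      }
    }
  }

  // Raw wrapper: delegates aggregation and returns the intermediate state instead of its size.
  static class RawFn {
    final String column;
    final DistinctCountFn delegate;

    RawFn(String column) {
      this(column, new DistinctCountFn(column));
    }

    RawFn(String column, DistinctCountFn delegate) {
      this.column = column;
      this.delegate = delegate;
    }

    void aggregate(Object[] rows, Set<Integer> state) {
      delegate.aggregate(rows, state);
    }

    Set<Integer> extractFinalResult(Set<Integer> state) {
      return state;
    }
  }

  // The MV raw variant only swaps the delegate, as DistinctCountRawHLLMVAggregationFunction does.
  static class RawMVFn extends RawFn {
    RawMVFn(String column) {
      super(column, new DistinctCountMVFn(column));
    }
  }

  public static void main(String[] args) {
    Set<Integer> state = new HashSet<>();
    RawFn raw = new RawMVFn("tags");
    raw.aggregate(new Object[]{new int[]{1, 2}, new int[]{2, 3}}, state);
    System.out.println(raw.extractFinalResult(state).size()); // 3
  }
}
```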
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLMVAggregationFunction.java
index 2a1cf2a..3712599 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLMVAggregationFunction.java
@@ -23,12 +23,8 @@ import org.apache.pinot.common.function.AggregationFunctionType;
 
 public class DistinctCountRawHLLMVAggregationFunction extends DistinctCountRawHLLAggregationFunction {
 
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
   public DistinctCountRawHLLMVAggregationFunction(String column) {
-    super(new DistinctCountHLLMVAggregationFunction(column), column);
+    super(column, new DistinctCountHLLMVAggregationFunction(column));
   }
 
   @Override
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java
index 2dd03d3..b7463a0 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java
@@ -64,6 +64,7 @@ import org.apache.pinot.sql.parsers.CalciteSqlParser;
 public class DistinctCountThetaSketchAggregationFunction implements AggregationFunction<Map<String, Sketch>, Integer> {
 
   private String _thetaSketchColumn;
+  private TransformExpressionTree _thetaSketchIdentifier;
   private Set<String> _predicateStrings;
   private Expression _postAggregationExpression;
   private Set<PredicateInfo> _predicateInfoSet;
@@ -130,15 +131,15 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
 
   @Override
   public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
 
     Map<String, Union> result = getDefaultResult(aggregationResultHolder, _predicateStrings);
-    Sketch[] sketches = deserializeSketches(blockValSetMap.get(_thetaSketchColumn).getBytesValuesSV(), length);
+    Sketch[] sketches = deserializeSketches(blockValSetMap.get(_thetaSketchIdentifier).getBytesValuesSV(), length);
 
     for (PredicateInfo predicateInfo : _predicateInfoSet) {
       String predicate = predicateInfo.getStringVal();
 
-      BlockValSet blockValSet = blockValSetMap.get(predicateInfo.getColumn());
+      BlockValSet blockValSet = blockValSetMap.get(predicateInfo.getExpression());
       FieldSpec.DataType valueType = blockValSet.getValueType();
       PredicateEvaluator predicateEvaluator = predicateInfo.getPredicateEvaluator(valueType);
 
@@ -181,13 +182,13 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    Sketch[] sketches = deserializeSketches(blockValSetMap.get(_thetaSketchColumn).getBytesValuesSV(), length);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    Sketch[] sketches = deserializeSketches(blockValSetMap.get(_thetaSketchIdentifier).getBytesValuesSV(), length);
 
     for (PredicateInfo predicateInfo : _predicateInfoSet) {
       String predicate = predicateInfo.getStringVal();
 
-      BlockValSet blockValSet = blockValSetMap.get(predicateInfo.getColumn());
+      BlockValSet blockValSet = blockValSetMap.get(predicateInfo.getExpression());
       FieldSpec.DataType valueType = blockValSet.getValueType();
       PredicateEvaluator predicateEvaluator = predicateInfo.getPredicateEvaluator(valueType);
 
@@ -237,13 +238,13 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    Sketch[] sketches = deserializeSketches(blockValSetMap.get(_thetaSketchColumn).getBytesValuesSV(), length);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    Sketch[] sketches = deserializeSketches(blockValSetMap.get(_thetaSketchIdentifier).getBytesValuesSV(), length);
 
     for (PredicateInfo predicateInfo : _predicateInfoSet) {
       String predicate = predicateInfo.getStringVal();
 
-      BlockValSet blockValSet = blockValSetMap.get(predicateInfo.getColumn());
+      BlockValSet blockValSet = blockValSetMap.get(predicateInfo.getExpression());
       FieldSpec.DataType valueType = blockValSet.getValueType();
       PredicateEvaluator predicateEvaluator = predicateInfo.getPredicateEvaluator(valueType);
 
@@ -425,10 +426,12 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
 
     // Initialize the Theta-Sketch Column.
     _thetaSketchColumn = arguments.get(0);
+    _thetaSketchIdentifier =
+        new TransformExpressionTree(TransformExpressionTree.ExpressionType.IDENTIFIER, _thetaSketchColumn, null);
 
     // Initialize input expressions. It is expected they are covered between the theta-sketch column and the predicates.
     _inputExpressions = new ArrayList<>();
-    _inputExpressions.add(TransformExpressionTree.compileToExpressionTree(_thetaSketchColumn));
+    _inputExpressions.add(_thetaSketchIdentifier);
 
     // Initialize thetaSketchParams
     String paramsString = arguments.get(1);
@@ -443,14 +446,18 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
 
     if (predicatesSpecified) {
       for (String predicateString : _predicateStrings) {
+        // FIXME: Standardize predicate string?
+
         Expression expression = CalciteSqlParser.compileToExpression(predicateString);
 
         // TODO: Add support for complex predicates with AND/OR.
         String filterColumn = ParserUtils.getFilterColumn(expression);
         Predicate predicate = Predicate
             .newPredicate(ParserUtils.getFilterType(expression), filterColumn, ParserUtils.getFilterValues(expression));
+        TransformExpressionTree filterExpression =
+            new TransformExpressionTree(TransformExpressionTree.ExpressionType.IDENTIFIER, filterColumn, null);
 
-        _predicateInfoSet.add(new PredicateInfo(predicateString, filterColumn, predicate));
+        _predicateInfoSet.add(new PredicateInfo(predicateString, filterExpression, predicate));
         _expressionMap.put(expression, predicateString);
         _inputExpressions.add(new TransformExpressionTree(new IdentifierAstNode(filterColumn)));
       }
@@ -461,12 +468,14 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
         String filterColumn = ParserUtils.getFilterColumn(predicateExpression);
         Predicate predicate = Predicate.newPredicate(ParserUtils.getFilterType(predicateExpression), filterColumn,
             ParserUtils.getFilterValues(predicateExpression));
+        TransformExpressionTree filterExpression =
+            new TransformExpressionTree(TransformExpressionTree.ExpressionType.IDENTIFIER, filterColumn, null);
 
         String predicateString = ParserUtils.standardizeExpression(predicateExpression, false);
         _predicateStrings.add(predicateString);
-        _predicateInfoSet.add(new PredicateInfo(predicateString, filterColumn, predicate));
+        _predicateInfoSet.add(new PredicateInfo(predicateString, filterExpression, predicate));
         _expressionMap.put(predicateExpression, predicateString);
-        _inputExpressions.add(new TransformExpressionTree(new IdentifierAstNode(filterColumn)));
+        _inputExpressions.add(filterExpression);
       }
     }
   }
@@ -567,15 +576,15 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
    *
    */
   private static class PredicateInfo {
-
     private final String _stringVal;
-    private final String _column; // LHS
+    private final TransformExpressionTree _expression; // LHS
+    // FIXME: Predicate does not have equals() and hashCode() implemented
     private final Predicate _predicate;
     private PredicateEvaluator _predicateEvaluator;
 
-    private PredicateInfo(String stringVal, String column, Predicate predicate) {
+    private PredicateInfo(String stringVal, TransformExpressionTree expression, Predicate predicate) {
       _stringVal = stringVal;
-      _column = column;
+      _expression = expression;
       _predicate = predicate;
       _predicateEvaluator = null; // Initialized lazily
     }
@@ -584,8 +593,8 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
       return _stringVal;
     }
 
-    public String getColumn() {
-      return _column;
+    public TransformExpressionTree getExpression() {
+      return _expression;
     }
 
     public Predicate getPredicate() {
@@ -625,13 +634,13 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
       }
 
       PredicateInfo that = (PredicateInfo) o;
-      return Objects.equals(_stringVal, that._stringVal) && Objects.equals(_column, that._column) && Objects
+      return Objects.equals(_stringVal, that._stringVal) && Objects.equals(_expression, that._expression) && Objects
           .equals(_predicate, that._predicate);
     }
 
     @Override
     public int hashCode() {
-      return Objects.hash(_stringVal, _column, _predicate);
+      return Objects.hash(_stringVal, _expression, _predicate);
     }
   }
-}
\ No newline at end of file
+}
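PredicateInfo's equals() and hashCode() now compare the TransformExpressionTree field instead of the column string. The new FIXME notes that Predicate does not implement equals() or hashCode(). The sketch below uses hypothetical stand-ins: Pred for Predicate, and a String for the expression tree. It shows the effect on a HashSet such as _predicateInfoSet. Two entries built from separate but logically identical predicates compare unequal, because the Pred field falls back to identity equality.

This excerpt does not show whether such duplicates occur in practice here, since the predicate strings are collected in a Set.

```java
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;

public class PredicateInfoEqualitySketch {

  // Stand-in for Predicate: like the class flagged by the FIXME, it does not override equals()/hashCode().
  static final class Pred {
    final String column;
    final String value;

    Pred(String column, String value) {
      this.column = column;
      this.value = value;
    }
  }

  // Stand-in for PredicateInfo: equality combines all three fields, as in the diff.
  static final class Info {
    final String stringVal;
    final String expression; // stands in for TransformExpressionTree
    final Pred predicate;

    Info(String stringVal, String expression, Pred predicate) {
      this.stringVal = stringVal;
      this.expression = expression;
      this.predicate = predicate;
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) {
        return true;
      }
      if (!(o instanceof Info)) {
        return false;
      }
      Info that = (Info) o;
      return Objects.equals(stringVal, that.stringVal) && Objects.equals(expression, that.expression)
          && Objects.equals(predicate, that.predicate);
    }

    @Override
    public int hashCode() {
      return Objects.hash(stringVal, expression, predicate);
    }
  }

  public static void main(String[] args) {
    Set<Info> separate = new HashSet<>();
    separate.add(new Info("country = 'US'", "country", new Pred("country", "US")));
    separate.add(new Info("country = 'US'", "country", new Pred("country", "US")));
    System.out.println(separate.size()); // 2: distinct Pred instances compare by identity

    Pred shared = new Pred("country", "US");
    Set<Info> sharedSet = new HashSet<>();
    sharedSet.add(new Info("country = 'US'", "country", shared));
    sharedSet.add(new Info("country = 'US'", "country", shared));
    System.out.println(sharedSet.size()); // 1
  }
}
```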
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FastHLLAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FastHLLAggregationFunction.java
index 12d92cc..0f7022e 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FastHLLAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FastHLLAggregationFunction.java
@@ -20,8 +20,6 @@ package org.apache.pinot.core.query.aggregation.function;
 
 import com.clearspring.analytics.stream.cardinality.HyperLogLog;
 import com.google.common.base.Preconditions;
-import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -38,20 +36,12 @@ import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder
  * Use {@link DistinctCountHLLAggregationFunction} on byte[] for serialized HyperLogLog.
  */
 @Deprecated
-public class FastHLLAggregationFunction implements AggregationFunction<HyperLogLog, Long> {
+public class FastHLLAggregationFunction extends BaseSingleInputAggregationFunction<HyperLogLog, Long> {
   public static final int DEFAULT_LOG2M = 8;
   private static final int BYTE_TO_CHAR_OFFSET = 129;
 
-  private final String _column;
-  private final List<TransformExpressionTree> _inputExpressions;
-
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
   public FastHLLAggregationFunction(String column) {
-    _column = column;
-    _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+    super(column);
   }
 
   @Override
@@ -60,21 +50,6 @@ public class FastHLLAggregationFunction implements AggregationFunction<HyperLogL
   }
 
   @Override
-  public String getColumnName() {
-    return getType().getName() + "_" + _column;
-  }
-
-  @Override
-  public String getResultColumnName() {
-    return getType().getName().toLowerCase() + "(" + _column + ")";
-  }
-
-  @Override
-  public List<TransformExpressionTree> getInputExpressions() {
-    return _inputExpressions;
-  }
-
-  @Override
   public void accept(AggregationFunctionVisitorBase visitor) {
     visitor.visit(this);
   }
@@ -90,8 +65,9 @@ public class FastHLLAggregationFunction implements AggregationFunction<HyperLogL
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
-    String[] values = blockValSetMap.get(_column).getStringValuesSV();
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    String[] values = blockValSetMap.get(_expression).getStringValuesSV();
     try {
       HyperLogLog hyperLogLog = aggregationResultHolder.getResult();
       if (hyperLogLog != null) {
@@ -112,8 +88,8 @@ public class FastHLLAggregationFunction implements AggregationFunction<HyperLogL
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    String[] values = blockValSetMap.get(_column).getStringValuesSV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    String[] values = blockValSetMap.get(_expression).getStringValuesSV();
     try {
       for (int i = 0; i < length; i++) {
         HyperLogLog value = convertStringToHLL(values[i]);
@@ -132,8 +108,8 @@ public class FastHLLAggregationFunction implements AggregationFunction<HyperLogL
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    String[] values = blockValSetMap.get(_column).getStringValuesSV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    String[] values = blockValSetMap.get(_expression).getStringValuesSV();
     try {
       for (int i = 0; i < length; i++) {
         HyperLogLog value = convertStringToHLL(values[i]);
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java
index a9b3ff3..5d141b3 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java
@@ -18,8 +18,6 @@
  */
 package org.apache.pinot.core.query.aggregation.function;
 
-import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -31,19 +29,11 @@ import org.apache.pinot.core.query.aggregation.groupby.DoubleGroupByResultHolder
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 
 
-public class MaxAggregationFunction implements AggregationFunction<Double, Double> {
+public class MaxAggregationFunction extends BaseSingleInputAggregationFunction<Double, Double> {
   private static final double DEFAULT_INITIAL_VALUE = Double.NEGATIVE_INFINITY;
 
-  protected final String _column;
-  private final List<TransformExpressionTree> _inputExpressions;
-
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
   public MaxAggregationFunction(String column) {
-    _column = column;
-    _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+    super(column);
   }
 
   @Override
@@ -52,21 +42,6 @@ public class MaxAggregationFunction implements AggregationFunction<Double, Doubl
   }
 
   @Override
-  public String getColumnName() {
-    return getType().getName() + "_" + _column;
-  }
-
-  @Override
-  public String getResultColumnName() {
-    return getType().getName().toLowerCase() + "(" + _column + ")";
-  }
-
-  @Override
-  public List<TransformExpressionTree> getInputExpressions() {
-    return _inputExpressions;
-  }
-
-  @Override
   public void accept(AggregationFunctionVisitorBase visitor) {
     visitor.visit(this);
   }
@@ -82,8 +57,9 @@ public class MaxAggregationFunction implements AggregationFunction<Double, Doubl
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
-    double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
     double max = aggregationResultHolder.getDoubleResult();
     for (int i = 0; i < length; i++) {
       double value = valueArray[i];
@@ -96,8 +72,8 @@ public class MaxAggregationFunction implements AggregationFunction<Double, Doubl
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
     for (int i = 0; i < length; i++) {
       double value = valueArray[i];
       int groupKey = groupKeyArray[i];
@@ -109,8 +85,8 @@ public class MaxAggregationFunction implements AggregationFunction<Double, Doubl
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
     for (int i = 0; i < length; i++) {
       double value = valueArray[i];
       for (int groupKey : groupKeysArray[i]) {
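The Max, Min, MinMaxRange and Percentile diffs all delete the same three members (`getColumnName`, `getResultColumnName`, `getInputExpressions`) and the `_column`/`_inputExpressions` fields, then call `super(column)`. They also switch the `blockValSetMap` key from `String` to an expression object. Below is a minimal sketch of that pattern, assuming a simplified model. `Expr` is a hypothetical stand-in for `TransformExpressionTree`, `double[]` stands in for `BlockValSet`, and `SingleInputAggSketch`/`MaxSketch` are illustrative names, not Pinot's actual classes. The map lookup only works because `Expr` implements structural `equals()`/`hashCode()`, which is the property the commit message relies on.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

// Hypothetical stand-in for TransformExpressionTree: immutable, with structural
// equals()/hashCode() so that a freshly built tree finds the same map entry.
final class Expr {
  final String name;          // column name, or function name for a function node
  final List<Expr> children;  // empty for a plain column

  Expr(String name, List<Expr> children) {
    this.name = name;
    this.children = List.copyOf(children);
  }

  static Expr column(String name) {
    return new Expr(name, Collections.emptyList());
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) return true;
    if (!(o instanceof Expr)) return false;
    Expr other = (Expr) o;
    return name.equals(other.name) && children.equals(other.children);
  }

  @Override
  public int hashCode() {
    return Objects.hash(name, children);
  }
}

// Sketch of a single-input base class: the column string and its expression are
// stored once, and the naming methods removed from each subclass live here instead.
abstract class SingleInputAggSketch {
  protected final String _column;
  protected final Expr _expression;

  SingleInputAggSketch(String column) {
    _column = column;
    _expression = Expr.column(column); // real code compiles the string into an expression tree
  }

  abstract String typeName();

  String getColumnName() {
    return typeName() + "_" + _column;
  }

  String getResultColumnName() {
    return typeName().toLowerCase() + "(" + _column + ")";
  }

  List<Expr> getInputExpressions() {
    return Collections.singletonList(_expression);
  }

  abstract double aggregate(int length, Map<Expr, double[]> blockValSetMap);
}

class MaxSketch extends SingleInputAggSketch {
  MaxSketch(String column) {
    super(column);
  }

  @Override
  String typeName() {
    return "max";
  }

  @Override
  double aggregate(int length, Map<Expr, double[]> blockValSetMap) {
    double[] values = blockValSetMap.get(_expression); // keyed by expression, not by string
    double max = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < length; i++) {
      max = Math.max(max, values[i]);
    }
    return max;
  }
}
```

Building the map with a separately constructed `Expr.column("metric")` still reaches the function's `_expression` entry. That equality requirement is what makes the map keys line up with `getInputExpressions()`.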
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxMVAggregationFunction.java
index 3083f87..434bcea 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxMVAggregationFunction.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.query.aggregation.function;
 
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -27,10 +28,6 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 
 public class MaxMVAggregationFunction extends MaxAggregationFunction {
 
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
   public MaxMVAggregationFunction(String column) {
     super(column);
   }
@@ -46,8 +43,9 @@ public class MaxMVAggregationFunction extends MaxAggregationFunction {
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     double max = aggregationResultHolder.getDoubleResult();
     for (int i = 0; i < length; i++) {
       for (double value : valuesArray[i]) {
@@ -61,8 +59,8 @@ public class MaxMVAggregationFunction extends MaxAggregationFunction {
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     for (int i = 0; i < length; i++) {
       int groupKey = groupKeyArray[i];
       double max = groupByResultHolder.getDoubleResult(groupKey);
@@ -77,8 +75,8 @@ public class MaxMVAggregationFunction extends MaxAggregationFunction {
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     for (int i = 0; i < length; i++) {
       double[] values = valuesArray[i];
       for (int groupKey : groupKeysArray[i]) {
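In `aggregateGroupByMV` for a multi-value column, each row carries several values (`double[][] valuesArray`) and may belong to several groups (`int[][] groupKeysArray`). Every group key of a row has to absorb every value of that row. The following sketch models this with a plain `double[]` as a hypothetical stand-in for `GroupByResultHolder`, and `MaxMvGroupBySketch` is an illustrative name. One small design difference from the diff: it computes the row's max once and folds it into each group, rather than rescanning the values for every group key. The result is the same.

```java
import java.util.Arrays;

final class MaxMvGroupBySketch {
  // results[groupKey] holds the running max for that group, starting at -infinity.
  static double[] newResults(int numGroups) {
    double[] results = new double[numGroups];
    Arrays.fill(results, Double.NEGATIVE_INFINITY);
    return results;
  }

  static void aggregateGroupByMV(int length, int[][] groupKeysArray, double[][] valuesArray,
      double[] results) {
    for (int i = 0; i < length; i++) {
      // Max over all values of row i (stays -infinity for an empty row)
      double rowMax = Double.NEGATIVE_INFINITY;
      for (double value : valuesArray[i]) {
        rowMax = Math.max(rowMax, value);
      }
      // Fold that max into every group the row belongs to
      for (int groupKey : groupKeysArray[i]) {
        results[groupKey] = Math.max(results[groupKey], rowMax);
      }
    }
  }
}
```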
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java
index ddf451a..68aca61 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java
@@ -18,8 +18,6 @@
  */
 package org.apache.pinot.core.query.aggregation.function;
 
-import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -31,19 +29,11 @@ import org.apache.pinot.core.query.aggregation.groupby.DoubleGroupByResultHolder
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 
 
-public class MinAggregationFunction implements AggregationFunction<Double, Double> {
+public class MinAggregationFunction extends BaseSingleInputAggregationFunction<Double, Double> {
   private static final double DEFAULT_VALUE = Double.POSITIVE_INFINITY;
 
-  protected final String _column;
-  private final List<TransformExpressionTree> _inputExpressions;
-
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
   public MinAggregationFunction(String column) {
-    _column = column;
-    _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+    super(column);
   }
 
   @Override
@@ -52,21 +42,6 @@ public class MinAggregationFunction implements AggregationFunction<Double, Doubl
   }
 
   @Override
-  public String getColumnName() {
-    return getType().getName() + "_" + _column;
-  }
-
-  @Override
-  public String getResultColumnName() {
-    return getType().getName().toLowerCase() + "(" + _column + ")";
-  }
-
-  @Override
-  public List<TransformExpressionTree> getInputExpressions() {
-    return _inputExpressions;
-  }
-
-  @Override
   public void accept(AggregationFunctionVisitorBase visitor) {
     visitor.visit(this);
   }
@@ -82,8 +57,9 @@ public class MinAggregationFunction implements AggregationFunction<Double, Doubl
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
-    double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
     double min = aggregationResultHolder.getDoubleResult();
     for (int i = 0; i < length; i++) {
       double value = valueArray[i];
@@ -96,8 +72,8 @@ public class MinAggregationFunction implements AggregationFunction<Double, Doubl
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
     for (int i = 0; i < length; i++) {
       double value = valueArray[i];
       int groupKey = groupKeyArray[i];
@@ -109,8 +85,8 @@ public class MinAggregationFunction implements AggregationFunction<Double, Doubl
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
     for (int i = 0; i < length; i++) {
       double value = valueArray[i];
       for (int groupKey : groupKeysArray[i]) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMVAggregationFunction.java
index f7b3141..e97a405 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMVAggregationFunction.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.query.aggregation.function;
 
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -27,10 +28,6 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 
 public class MinMVAggregationFunction extends MinAggregationFunction {
 
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
   public MinMVAggregationFunction(String column) {
     super(column);
   }
@@ -46,8 +43,9 @@ public class MinMVAggregationFunction extends MinAggregationFunction {
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     double min = aggregationResultHolder.getDoubleResult();
     for (int i = 0; i < length; i++) {
       for (double value : valuesArray[i]) {
@@ -61,8 +59,8 @@ public class MinMVAggregationFunction extends MinAggregationFunction {
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     for (int i = 0; i < length; i++) {
       int groupKey = groupKeyArray[i];
       double min = groupByResultHolder.getDoubleResult(groupKey);
@@ -77,8 +75,8 @@ public class MinMVAggregationFunction extends MinAggregationFunction {
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     for (int i = 0; i < length; i++) {
       double[] values = valuesArray[i];
       for (int groupKey : groupKeysArray[i]) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunction.java
index ca4ec10..53aa53a 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunction.java
@@ -18,8 +18,6 @@
  */
 package org.apache.pinot.core.query.aggregation.function;
 
-import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -34,18 +32,10 @@ import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder
 import org.apache.pinot.spi.data.FieldSpec.DataType;
 
 
-public class MinMaxRangeAggregationFunction implements AggregationFunction<MinMaxRangePair, Double> {
+public class MinMaxRangeAggregationFunction extends BaseSingleInputAggregationFunction<MinMaxRangePair, Double> {
 
-  protected final String _column;
-  private final List<TransformExpressionTree> _inputExpressions;
-
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
   public MinMaxRangeAggregationFunction(String column) {
-    _column = column;
-    _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+    super(column);
   }
 
   @Override
@@ -54,21 +44,6 @@ public class MinMaxRangeAggregationFunction implements AggregationFunction<MinMa
   }
 
   @Override
-  public String getColumnName() {
-    return getType().getName() + "_" + _column;
-  }
-
-  @Override
-  public String getResultColumnName() {
-    return getType().getName().toLowerCase() + "(" + _column + ")";
-  }
-
-  @Override
-  public List<TransformExpressionTree> getInputExpressions() {
-    return _inputExpressions;
-  }
-
-  @Override
   public void accept(AggregationFunctionVisitorBase visitor) {
     visitor.visit(this);
   }
@@ -85,11 +60,11 @@ public class MinMaxRangeAggregationFunction implements AggregationFunction<MinMa
 
   @Override
   public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
     double min = Double.POSITIVE_INFINITY;
     double max = Double.NEGATIVE_INFINITY;
 
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     if (blockValSet.getValueType() != DataType.BYTES) {
       double[] doubleValues = blockValSet.getDoubleValuesSV();
       for (int i = 0; i < length; i++) {
@@ -130,8 +105,8 @@ public class MinMaxRangeAggregationFunction implements AggregationFunction<MinMa
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     if (blockValSet.getValueType() != DataType.BYTES) {
       double[] doubleValues = blockValSet.getDoubleValuesSV();
       for (int i = 0; i < length; i++) {
@@ -150,8 +125,8 @@ public class MinMaxRangeAggregationFunction implements AggregationFunction<MinMa
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     if (blockValSet.getValueType() != DataType.BYTES) {
       double[] doubleValues = blockValSet.getDoubleValuesSV();
       for (int i = 0; i < length; i++) {
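`MinMaxRangeAggregationFunction` declares `MinMaxRangePair` as its intermediate type and `Double` as its final type. The range itself cannot be merged: two partial ranges do not determine the combined range. The min and max behind them do, so the pair is what travels between segments. A minimal sketch follows, with `RangePair` as a hypothetical simplification of `MinMaxRangePair`.

```java
final class RangePair {
  double min = Double.POSITIVE_INFINITY;
  double max = Double.NEGATIVE_INFINITY;

  // Per-value update, as done inside aggregate()/aggregateGroupBy*()
  void apply(double value) {
    if (value < min) min = value;
    if (value > max) max = value;
  }

  // Combining partial results (e.g. from two segments) only needs each side's min and max
  void merge(RangePair other) {
    min = Math.min(min, other.min);
    max = Math.max(max, other.max);
  }

  // Final result: the spread between the largest and smallest value seen
  double range() {
    return max - min;
  }
}
```

For example, partials with ranges 7 (from 3 and 10) and 6 (from -2 and 4) merge into a range of 12, a value that neither 7 nor 6 alone could produce.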
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeMVAggregationFunction.java
index 5a18190..d88a098 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeMVAggregationFunction.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.query.aggregation.function;
 
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -27,10 +28,6 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 
 public class MinMaxRangeMVAggregationFunction extends MinMaxRangeAggregationFunction {
 
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
   public MinMaxRangeMVAggregationFunction(String column) {
     super(column);
   }
@@ -46,8 +43,9 @@ public class MinMaxRangeMVAggregationFunction extends MinMaxRangeAggregationFunc
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     double min = Double.POSITIVE_INFINITY;
     double max = Double.NEGATIVE_INFINITY;
     for (int i = 0; i < length; i++) {
@@ -66,8 +64,8 @@ public class MinMaxRangeMVAggregationFunction extends MinMaxRangeAggregationFunc
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     for (int i = 0; i < length; i++) {
       aggregateOnGroupKey(groupKeyArray[i], groupByResultHolder, valuesArray[i]);
     }
@@ -75,8 +73,8 @@ public class MinMaxRangeMVAggregationFunction extends MinMaxRangeAggregationFunc
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     for (int i = 0; i < length; i++) {
       double[] values = valuesArray[i];
       for (int groupKey : groupKeysArray[i]) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java
index 580e7f3..d779d71 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java
@@ -18,11 +18,8 @@
  */
 package org.apache.pinot.core.query.aggregation.function;
 
-import com.google.common.base.Preconditions;
 import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
 import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -34,29 +31,14 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
 
 
-public class PercentileAggregationFunction implements AggregationFunction<DoubleArrayList, Double> {
+public class PercentileAggregationFunction extends BaseSingleInputAggregationFunction<DoubleArrayList, Double> {
   private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY;
 
   protected final int _percentile;
-  protected final String _column;
-  private final List<TransformExpressionTree> _inputExpressions;
 
-  /**
-   * Constructor for the class.
-   *
-   * @param arguments List of arguments.
-   *                  <ul>
-   *                  <li> Arg 0: Column name to aggregate.</li>
-   *                  <li> Arg 1: Percentile to compute. </li>
-   *                  </ul>
-   */
-  public PercentileAggregationFunction(List<String> arguments) {
-    int numArgs = arguments.size();
-    Preconditions.checkArgument(numArgs == 2, getType() + " expects two argument, got: " + numArgs);
-
-    _column = arguments.get(0);
-    _percentile = AggregationFunctionUtils.parsePercentile(arguments.get(1));
-    _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+  public PercentileAggregationFunction(String column, int percentile) {
+    super(column);
+    _percentile = percentile;
   }
 
   @Override
@@ -66,17 +48,12 @@ public class PercentileAggregationFunction implements AggregationFunction<Double
 
   @Override
   public String getColumnName() {
-    return getType().getName() + _percentile + "_" + _column;
+    return AggregationFunctionType.PERCENTILE.getName() + _percentile + "_" + _column;
   }
 
   @Override
   public String getResultColumnName() {
-    return getType().getName().toLowerCase() + _percentile + "(" + _column + ")";
-  }
-
-  @Override
-  public List<TransformExpressionTree> getInputExpressions() {
-    return _inputExpressions;
+    return AggregationFunctionType.PERCENTILE.getName().toLowerCase() + _percentile + "(" + _column + ")";
   }
 
   @Override
@@ -95,9 +72,10 @@ public class PercentileAggregationFunction implements AggregationFunction<Double
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
     DoubleArrayList valueList = getValueList(aggregationResultHolder);
-    double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+    double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
     for (int i = 0; i < length; i++) {
       valueList.add(valueArray[i]);
     }
@@ -105,8 +83,8 @@ public class PercentileAggregationFunction implements AggregationFunction<Double
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
     for (int i = 0; i < length; i++) {
       DoubleArrayList valueList = getValueList(groupByResultHolder, groupKeyArray[i]);
       valueList.add(valueArray[i]);
@@ -115,8 +93,8 @@ public class PercentileAggregationFunction implements AggregationFunction<Double
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
     for (int i = 0; i < length; i++) {
       double value = valueArray[i];
       for (int groupKey : groupKeysArray[i]) {
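The Percentile constructors now take `(String column, int percentile)` instead of a raw argument list. The commit message says the percentile may come either from the function name (e.g. `PERCENTILE95(col)`) or from a second argument (e.g. `PERCENTILE(col, 99)`). The parsing is therefore done before construction. This diff does not show where Pinot does it, so the helper below is a hypothetical sketch of how both shapes could be normalized into a `(column, percentile)` pair. `PercentileArgs` and `parse` are illustrative names, not Pinot APIs. The `prefix` parameter distinguishes `PERCENTILE` from longer names such as `PERCENTILETDIGEST`.

```java
import java.util.List;

final class PercentileArgs {
  final String column;
  final int percentile;

  PercentileArgs(String column, int percentile) {
    if (percentile < 0 || percentile > 100) {
      throw new IllegalArgumentException("Invalid percentile: " + percentile);
    }
    this.column = column;
    this.percentile = percentile;
  }

  // Hypothetical helper accepting both PERCENTILE95(col) and PERCENTILE(col, 95).
  static PercentileArgs parse(String prefix, String functionName, List<String> arguments) {
    String name = functionName.toUpperCase();
    if (!name.startsWith(prefix)) {
      throw new IllegalArgumentException("Not a " + prefix + " function: " + functionName);
    }
    String suffix = name.substring(prefix.length());
    if (suffix.isEmpty()) {
      // Percentile passed as the second argument
      if (arguments.size() != 2) {
        throw new IllegalArgumentException(prefix + " expects 2 arguments, got: " + arguments.size());
      }
      return new PercentileArgs(arguments.get(0), Integer.parseInt(arguments.get(1).trim()));
    }
    // Percentile embedded in the function name
    if (arguments.size() != 1) {
      throw new IllegalArgumentException(name + " expects 1 argument, got: " + arguments.size());
    }
    return new PercentileArgs(arguments.get(0), Integer.parseInt(suffix));
  }
}
```

The parsed pair would then feed a constructor such as `new PercentileAggregationFunction(args.column, args.percentile)`.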
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java
index 00cfd1d..5a89731 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java
@@ -18,9 +18,6 @@
  */
 package org.apache.pinot.core.query.aggregation.function;
 
-import com.google.common.base.Preconditions;
-import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -35,29 +32,14 @@ import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder
 import org.apache.pinot.spi.data.FieldSpec.DataType;
 
 
-public class PercentileEstAggregationFunction implements AggregationFunction<QuantileDigest, Long> {
+public class PercentileEstAggregationFunction extends BaseSingleInputAggregationFunction<QuantileDigest, Long> {
   public static final double DEFAULT_MAX_ERROR = 0.05;
 
   protected final int _percentile;
-  protected final String _column;
-  private final List<TransformExpressionTree> _inputExpressions;
 
-  /**
-   * Constructor for the class.
-   *
-   * @param arguments List of arguments.
-   *                  <ul>
-   *                  <li> Arg 0: Column name to aggregate.</li>
-   *                  <li> Arg 1: Percentile to compute. </li>
-   *                  </ul>
-   */
-  public PercentileEstAggregationFunction(List<String> arguments) {
-    int numArgs = arguments.size();
-    Preconditions.checkArgument(numArgs == 2, getType() + " expects two argument, got: " + numArgs);
-
-    _column = arguments.get(0);
-    _percentile = AggregationFunctionUtils.parsePercentile(arguments.get(1));
-    _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+  public PercentileEstAggregationFunction(String column, int percentile) {
+    super(column);
+    _percentile = percentile;
   }
 
   @Override
@@ -76,11 +58,6 @@ public class PercentileEstAggregationFunction implements AggregationFunction<Qua
   }
 
   @Override
-  public List<TransformExpressionTree> getInputExpressions() {
-    return _inputExpressions;
-  }
-
-  @Override
   public void accept(AggregationFunctionVisitorBase visitor) {
     visitor.visit(this);
   }
@@ -97,8 +74,8 @@ public class PercentileEstAggregationFunction implements AggregationFunction<Qua
 
   @Override
   public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     if (blockValSet.getValueType() != DataType.BYTES) {
       long[] longValues = blockValSet.getLongValuesSV();
       QuantileDigest quantileDigest = getDefaultQuantileDigest(aggregationResultHolder);
@@ -125,8 +102,8 @@ public class PercentileEstAggregationFunction implements AggregationFunction<Qua
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     if (blockValSet.getValueType() != DataType.BYTES) {
       long[] longValues = blockValSet.getLongValuesSV();
       for (int i = 0; i < length; i++) {
@@ -150,8 +127,8 @@ public class PercentileEstAggregationFunction implements AggregationFunction<Qua
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     if (blockValSet.getValueType() != DataType.BYTES) {
       long[] longValues = blockValSet.getLongValuesSV();
       for (int i = 0; i < length; i++) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java
index 8809d42..bfc44b4 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java
@@ -18,9 +18,9 @@
  */
 package org.apache.pinot.core.query.aggregation.function;
 
-import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.function.customobject.QuantileDigest;
@@ -29,17 +29,8 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 
 public class PercentileEstMVAggregationFunction extends PercentileEstAggregationFunction {
 
-  /**
-   * Constructor for the class.
-   *
-   * @param arguments List of arguments.
-   *                  <ul>
-   *                  <li> Arg 0: Column name to aggregate.</li>
-   *                  <li> Arg 1: Percentile to compute. </li>
-   *                  </ul>
-   */
-  public PercentileEstMVAggregationFunction(List<String> arguments) {
-    super(arguments);
+  public PercentileEstMVAggregationFunction(String column, int percentile) {
+    super(column, percentile);
   }
 
   @Override
@@ -63,8 +54,9 @@ public class PercentileEstMVAggregationFunction extends PercentileEstAggregation
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
-    long[][] valuesArray = blockValSetMap.get(_column).getLongValuesMV();
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    long[][] valuesArray = blockValSetMap.get(_expression).getLongValuesMV();
     QuantileDigest quantileDigest = getDefaultQuantileDigest(aggregationResultHolder);
     for (int i = 0; i < length; i++) {
       for (long value : valuesArray[i]) {
@@ -75,8 +67,8 @@ public class PercentileEstMVAggregationFunction extends PercentileEstAggregation
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    long[][] valuesArray = blockValSetMap.get(_column).getLongValuesMV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    long[][] valuesArray = blockValSetMap.get(_expression).getLongValuesMV();
     for (int i = 0; i < length; i++) {
       QuantileDigest quantileDigest = getDefaultQuantileDigest(groupByResultHolder, groupKeyArray[i]);
       for (long value : valuesArray[i]) {
@@ -87,8 +79,8 @@ public class PercentileEstMVAggregationFunction extends PercentileEstAggregation
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    long[][] valuesArray = blockValSetMap.get(_column).getLongValuesMV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    long[][] valuesArray = blockValSetMap.get(_expression).getLongValuesMV();
     for (int i = 0; i < length; i++) {
       long[] values = valuesArray[i];
       for (int groupKey : groupKeysArray[i]) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java
index fba5f02..b02af5a 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java
@@ -19,9 +19,9 @@
 package org.apache.pinot.core.query.aggregation.function;
 
 import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
-import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -29,17 +29,8 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 
 public class PercentileMVAggregationFunction extends PercentileAggregationFunction {
 
-  /**
-   * Constructor for the class.
-   *
-   * @param arguments List of arguments.
-   *                  <ul>
-   *                  <li> Arg 0: Column name to aggregate.</li>
-   *                  <li> Arg 1: Percentile to compute. </li>
-   *                  </ul>
-   */
-  public PercentileMVAggregationFunction(List<String> arguments) {
-    super(arguments);
+  public PercentileMVAggregationFunction(String column, int percentile) {
+    super(column, percentile);
   }
 
   @Override
@@ -63,8 +54,9 @@ public class PercentileMVAggregationFunction extends PercentileAggregationFuncti
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     DoubleArrayList valueList = getValueList(aggregationResultHolder);
     for (int i = 0; i < length; i++) {
       for (double value : valuesArray[i]) {
@@ -75,8 +67,8 @@ public class PercentileMVAggregationFunction extends PercentileAggregationFuncti
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     for (int i = 0; i < length; i++) {
       DoubleArrayList valueList = getValueList(groupByResultHolder, groupKeyArray[i]);
       for (double value : valuesArray[i]) {
@@ -87,8 +79,8 @@ public class PercentileMVAggregationFunction extends PercentileAggregationFuncti
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     for (int i = 0; i < length; i++) {
       double[] values = valuesArray[i];
       for (int groupKey : groupKeysArray[i]) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestAggregationFunction.java
index faa9d71..bbe76dd 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestAggregationFunction.java
@@ -18,10 +18,7 @@
  */
 package org.apache.pinot.core.query.aggregation.function;
 
-import com.google.common.base.Preconditions;
 import com.tdunning.math.stats.TDigest;
-import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -38,28 +35,14 @@ import org.apache.pinot.spi.data.FieldSpec.DataType;
 /**
  * TDigest based Percentile aggregation function.
  */
-public class PercentileTDigestAggregationFunction implements AggregationFunction<TDigest, Double> {
+public class PercentileTDigestAggregationFunction extends BaseSingleInputAggregationFunction<TDigest, Double> {
   public static final int DEFAULT_TDIGEST_COMPRESSION = 100;
 
   protected final int _percentile;
-  protected final String _column;
-  private final List<TransformExpressionTree> _inputExpressions;
 
-  /**
-   * Constructor for the class.
-   *
-   * @param arguments List of arguments.
-   *                  <ul>
-   *                  <li> Arg 0: Column name to aggregate.</li>
-   *                  <li> Arg 1: Percentile to compute. </li>
-   *                  </ul>
-   */
-  public PercentileTDigestAggregationFunction(List<String> arguments) {
-    int numArgs = arguments.size();
-    Preconditions.checkArgument(numArgs == 2, getType() + " expects two argument, got: " + numArgs);
-    _column = arguments.get(0);
-    _percentile = AggregationFunctionUtils.parsePercentile(arguments.get(1));
-    _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+  public PercentileTDigestAggregationFunction(String column, int percentile) {
+    super(column);
+    _percentile = percentile;
   }
 
   @Override
@@ -78,11 +61,6 @@ public class PercentileTDigestAggregationFunction implements AggregationFunction
   }
 
   @Override
-  public List<TransformExpressionTree> getInputExpressions() {
-    return _inputExpressions;
-  }
-
-  @Override
   public void accept(AggregationFunctionVisitorBase visitor) {
     visitor.visit(this);
   }
@@ -98,8 +76,9 @@ public class PercentileTDigestAggregationFunction implements AggregationFunction
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     if (blockValSet.getValueType() != DataType.BYTES) {
       double[] doubleValues = blockValSet.getDoubleValuesSV();
       TDigest tDigest = getDefaultTDigest(aggregationResultHolder);
@@ -126,8 +105,8 @@ public class PercentileTDigestAggregationFunction implements AggregationFunction
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     if (blockValSet.getValueType() != DataType.BYTES) {
       double[] doubleValues = blockValSet.getDoubleValuesSV();
       for (int i = 0; i < length; i++) {
@@ -151,8 +130,8 @@ public class PercentileTDigestAggregationFunction implements AggregationFunction
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    BlockValSet blockValSet = blockValSetMap.get(_column);
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
     if (blockValSet.getValueType() != DataType.BYTES) {
       double[] doubleValues = blockValSet.getDoubleValuesSV();
       for (int i = 0; i < length; i++) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestMVAggregationFunction.java
index d71a7fe..01f4039 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestMVAggregationFunction.java
@@ -19,9 +19,9 @@
 package org.apache.pinot.core.query.aggregation.function;
 
 import com.tdunning.math.stats.TDigest;
-import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -29,17 +29,8 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 
 public class PercentileTDigestMVAggregationFunction extends PercentileTDigestAggregationFunction {
 
-  /**
-   * Constructor for the class.
-   *
-   * @param arguments List of arguments.
-   *                  <ul>
-   *                  <li> Arg 0: Column name to aggregate.</li>
-   *                  <li> Arg 1: Percentile to compute. </li>
-   *                  </ul>
-   */
-  public PercentileTDigestMVAggregationFunction(List<String> arguments) {
-    super(arguments);
+  public PercentileTDigestMVAggregationFunction(String column, int percentile) {
+    super(column, percentile);
   }
 
   @Override
@@ -63,8 +54,9 @@ public class PercentileTDigestMVAggregationFunction extends PercentileTDigestAgg
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     TDigest tDigest = getDefaultTDigest(aggregationResultHolder);
     for (int i = 0; i < length; i++) {
       for (double value : valuesArray[i]) {
@@ -75,8 +67,8 @@ public class PercentileTDigestMVAggregationFunction extends PercentileTDigestAgg
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     for (int i = 0; i < length; i++) {
       TDigest tDigest = getDefaultTDigest(groupByResultHolder, groupKeyArray[i]);
       for (double value : valuesArray[i]) {
@@ -87,8 +79,8 @@ public class PercentileTDigestMVAggregationFunction extends PercentileTDigestAgg
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     for (int i = 0; i < length; i++) {
       double[] values = valuesArray[i];
       for (int groupKey : groupKeysArray[i]) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java
index 2416919..7e71c21 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java
@@ -18,8 +18,6 @@
  */
 package org.apache.pinot.core.query.aggregation.function;
 
-import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -31,18 +29,11 @@ import org.apache.pinot.core.query.aggregation.groupby.DoubleGroupByResultHolder
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 
 
-public class SumAggregationFunction implements AggregationFunction<Double, Double> {
+public class SumAggregationFunction extends BaseSingleInputAggregationFunction<Double, Double> {
   private static final double DEFAULT_VALUE = 0.0;
-  protected final String _column;
-  private final List<TransformExpressionTree> _inputExpressions;
 
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
   public SumAggregationFunction(String column) {
-    _column = column;
-    _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+    super(column);
   }
 
   @Override
@@ -51,21 +42,6 @@ public class SumAggregationFunction implements AggregationFunction<Double, Doubl
   }
 
   @Override
-  public String getColumnName() {
-    return getType().getName() + "_" + _column;
-  }
-
-  @Override
-  public String getResultColumnName() {
-    return getType().getName().toLowerCase() + "(" + _column + ")";
-  }
-
-  @Override
-  public List<TransformExpressionTree> getInputExpressions() {
-    return _inputExpressions;
-  }
-
-  @Override
   public void accept(AggregationFunctionVisitorBase visitor) {
     visitor.visit(this);
   }
@@ -81,8 +57,9 @@ public class SumAggregationFunction implements AggregationFunction<Double, Doubl
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
-    double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
     double sum = aggregationResultHolder.getDoubleResult();
     for (int i = 0; i < length; i++) {
       sum += valueArray[i];
@@ -92,8 +69,8 @@ public class SumAggregationFunction implements AggregationFunction<Double, Doubl
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
     for (int i = 0; i < length; i++) {
       int groupKey = groupKeyArray[i];
       groupByResultHolder.setValueForKey(groupKey, groupByResultHolder.getDoubleResult(groupKey) + valueArray[i]);
@@ -102,8 +79,8 @@ public class SumAggregationFunction implements AggregationFunction<Double, Doubl
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
     for (int i = 0; i < length; i++) {
       double value = valueArray[i];
       for (int groupKey : groupKeysArray[i]) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumMVAggregationFunction.java
index 8c468a7..7d7f9a0 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumMVAggregationFunction.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.query.aggregation.function;
 
 import java.util.Map;
 import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -27,10 +28,6 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 
 public class SumMVAggregationFunction extends SumAggregationFunction {
 
-  /**
-   * Constructor for the class.
-   * @param column Column name to aggregate on.
-   */
   public SumMVAggregationFunction(String column) {
     super(column);
   }
@@ -46,8 +43,9 @@ public class SumMVAggregationFunction extends SumAggregationFunction {
   }
 
   @Override
-  public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     double sum = aggregationResultHolder.getDoubleResult();
     for (int i = 0; i < length; i++) {
       for (double value : valuesArray[i]) {
@@ -59,8 +57,8 @@ public class SumMVAggregationFunction extends SumAggregationFunction {
 
   @Override
   public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     for (int i = 0; i < length; i++) {
       int groupKey = groupKeyArray[i];
       double sum = groupByResultHolder.getDoubleResult(groupKey);
@@ -73,8 +71,8 @@ public class SumMVAggregationFunction extends SumAggregationFunction {
 
   @Override
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
-      Map<String, BlockValSet> blockValSetMap) {
-    double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+      Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+    double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
     for (int i = 0; i < length; i++) {
       double[] values = valuesArray[i];
       for (int groupKey : groupKeysArray[i]) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java
index dea2bd8..a2d561a 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java
@@ -18,21 +18,15 @@
  */
 package org.apache.pinot.core.query.aggregation.groupby;
 
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
-import javax.annotation.Nonnull;
-import org.apache.pinot.common.function.AggregationFunctionType;
-import org.apache.pinot.common.request.GroupBy;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.operator.blocks.TransformBlock;
 import org.apache.pinot.core.operator.transform.TransformOperator;
 import org.apache.pinot.core.operator.transform.TransformResultMetadata;
 import org.apache.pinot.core.plan.DocIdSetPlanNode;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 
 
 /**
@@ -51,62 +45,39 @@ public class DefaultGroupByExecutor implements GroupByExecutor {
   private static final ThreadLocal<int[][]> THREAD_LOCAL_MV_GROUP_KEYS =
       ThreadLocal.withInitial(() -> new int[DocIdSetPlanNode.MAX_DOC_PER_CALL][]);
 
-  protected final int _numFunctions;
-  protected final AggregationFunction[] _functions;
-  protected final TransformExpressionTree[][] _aggregationExpressions;
+  protected final AggregationFunction[] _aggregationFunctions;
   protected final GroupKeyGenerator _groupKeyGenerator;
-  protected final GroupByResultHolder[] _resultHolders;
+  protected final GroupByResultHolder[] _groupByResultHolders;
   protected final boolean _hasMVGroupByExpression;
-  protected final boolean _hasNoDictionaryGroupByExpression;
   protected final int[] _svGroupKeys;
   protected final int[][] _mvGroupKeys;
 
   /**
    * Constructor for the class.
    *
-   * @param functionContexts Array of aggregation functions
-   * @param groupBy Group by from broker request
+   * @param aggregationFunctions Array of aggregation functions
+   * @param groupByExpressions Array of group-by expressions
    * @param maxInitialResultHolderCapacity Maximum initial capacity for the result holder
    * @param numGroupsLimit Limit on number of aggregation groups returned in the result
    * @param transformOperator Transform operator
    */
-  public DefaultGroupByExecutor(@Nonnull AggregationFunctionContext[] functionContexts, @Nonnull GroupBy groupBy,
-      int maxInitialResultHolderCapacity, int numGroupsLimit, @Nonnull TransformOperator transformOperator) {
-    // Initialize aggregation functions and expressions
-    _numFunctions = functionContexts.length;
-    _functions = new AggregationFunction[_numFunctions];
-    _aggregationExpressions = new TransformExpressionTree[_numFunctions][];
+  public DefaultGroupByExecutor(AggregationFunction[] aggregationFunctions,
+      TransformExpressionTree[] groupByExpressions, int maxInitialResultHolderCapacity, int numGroupsLimit,
+      TransformOperator transformOperator) {
+    _aggregationFunctions = aggregationFunctions;
 
-    for (int i = 0; i < _numFunctions; i++) {
-      AggregationFunction function = functionContexts[i].getAggregationFunction();
-      _functions[i] = function;
-
-      if (function.getType() != AggregationFunctionType.COUNT) {
-        List<String> expressions = functionContexts[i].getExpressions();
-
-        List<TransformExpressionTree> inputExpressions = function.getInputExpressions();
-        _aggregationExpressions[i] = inputExpressions.toArray(new TransformExpressionTree[0]);
-      }
-    }
-
-    // Initialize group-by expressions
-    List<String> groupByExpressionStrings = groupBy.getExpressions();
-    int numGroupByExpressions = groupByExpressionStrings.size();
     boolean hasMVGroupByExpression = false;
     boolean hasNoDictionaryGroupByExpression = false;
-    TransformExpressionTree[] groupByExpressions = new TransformExpressionTree[numGroupByExpressions];
-    for (int i = 0; i < numGroupByExpressions; i++) {
-      groupByExpressions[i] = TransformExpressionTree.compileToExpressionTree(groupByExpressionStrings.get(i));
-      TransformResultMetadata transformResultMetadata = transformOperator.getResultMetadata(groupByExpressions[i]);
+    for (TransformExpressionTree groupByExpression : groupByExpressions) {
+      TransformResultMetadata transformResultMetadata = transformOperator.getResultMetadata(groupByExpression);
       hasMVGroupByExpression |= !transformResultMetadata.isSingleValue();
       hasNoDictionaryGroupByExpression |= !transformResultMetadata.hasDictionary();
     }
     _hasMVGroupByExpression = hasMVGroupByExpression;
-    _hasNoDictionaryGroupByExpression = hasNoDictionaryGroupByExpression;
 
     // Initialize group key generator
-    if (_hasNoDictionaryGroupByExpression) {
-      if (numGroupByExpressions == 1) {
+    if (hasNoDictionaryGroupByExpression) {
+      if (groupByExpressions.length == 1) {
         _groupKeyGenerator =
             new NoDictionarySingleColumnGroupKeyGenerator(transformOperator, groupByExpressions[0], numGroupsLimit);
       } else {
@@ -121,9 +92,10 @@ public class DefaultGroupByExecutor implements GroupByExecutor {
     // Initialize result holders
     int maxNumResults = _groupKeyGenerator.getGlobalGroupKeyUpperBound();
     int initialCapacity = Math.min(maxNumResults, maxInitialResultHolderCapacity);
-    _resultHolders = new GroupByResultHolder[_numFunctions];
-    for (int i = 0; i < _numFunctions; i++) {
-      _resultHolders[i] = _functions[i].createGroupByResultHolder(initialCapacity, maxNumResults);
+    int numAggregationFunctions = aggregationFunctions.length;
+    _groupByResultHolders = new GroupByResultHolder[numAggregationFunctions];
+    for (int i = 0; i < numAggregationFunctions; i++) {
+      _groupByResultHolders[i] = _aggregationFunctions[i].createGroupByResultHolder(initialCapacity, maxNumResults);
     }
 
     // Initialize map from document Id to group key
@@ -137,7 +109,7 @@ public class DefaultGroupByExecutor implements GroupByExecutor {
   }
 
   @Override
-  public void process(@Nonnull TransformBlock transformBlock) {
+  public void process(TransformBlock transformBlock) {
     // Generate group keys
     // NOTE: groupKeyGenerator will limit the number of groups. Once reaching limit, no new group will be generated
     if (_hasMVGroupByExpression) {
@@ -146,41 +118,31 @@ public class DefaultGroupByExecutor implements GroupByExecutor {
       _groupKeyGenerator.generateKeysForBlock(transformBlock, _svGroupKeys);
     }
 
-    int length = transformBlock.getNumDocs();
     int capacityNeeded = _groupKeyGenerator.getCurrentGroupKeyUpperBound();
-    for (int i = 0; i < _numFunctions; i++) {
-      GroupByResultHolder resultHolder = _resultHolders[i];
-      resultHolder.ensureCapacity(capacityNeeded);
+    int length = transformBlock.getNumDocs();
+    int numAggregationFunctions = _aggregationFunctions.length;
+    for (int i = 0; i < numAggregationFunctions; i++) {
+      GroupByResultHolder groupByResultHolder = _groupByResultHolders[i];
+      groupByResultHolder.ensureCapacity(capacityNeeded);
       aggregate(transformBlock, length, i);
     }
   }
 
-  protected void aggregate(@Nonnull TransformBlock transformBlock, int length, int functionIndex) {
-    AggregationFunction function = _functions[functionIndex];
-    GroupByResultHolder resultHolder = _resultHolders[functionIndex];
+  protected void aggregate(TransformBlock transformBlock, int length, int functionIndex) {
+    AggregationFunction aggregationFunction = _aggregationFunctions[functionIndex];
+    Map<TransformExpressionTree, BlockValSet> blockValSetMap =
+        AggregationFunctionUtils.getBlockValSetMap(aggregationFunction, transformBlock);
 
-    if (function.getType() == AggregationFunctionType.COUNT) {
-      if (_hasMVGroupByExpression) {
-        function.aggregateGroupByMV(length, _mvGroupKeys, resultHolder, Collections.emptyMap());
-      } else {
-        function.aggregateGroupBySV(length, _svGroupKeys, resultHolder, Collections.emptyMap());
-      }
+    GroupByResultHolder groupByResultHolder = _groupByResultHolders[functionIndex];
+    if (_hasMVGroupByExpression) {
+      aggregationFunction.aggregateGroupByMV(length, _mvGroupKeys, groupByResultHolder, blockValSetMap);
     } else {
-      Map<String, BlockValSet> blockValSetMap = new HashMap<>();
-      for (int i = 0; i < _aggregationExpressions[functionIndex].length; i++) {
-        TransformExpressionTree aggregationExpression = _aggregationExpressions[functionIndex][i];
-        blockValSetMap.put(aggregationExpression.toString(), transformBlock.getBlockValueSet(aggregationExpression));
-      }
-      if (_hasMVGroupByExpression) {
-        function.aggregateGroupByMV(length, _mvGroupKeys, resultHolder, blockValSetMap);
-      } else {
-        function.aggregateGroupBySV(length, _svGroupKeys, resultHolder, blockValSetMap);
-      }
+      aggregationFunction.aggregateGroupBySV(length, _svGroupKeys, groupByResultHolder, blockValSetMap);
     }
   }
 
   @Override
   public AggregationGroupByResult getResult() {
-    return new AggregationGroupByResult(_groupKeyGenerator, _functions, _resultHolders);
+    return new AggregationGroupByResult(_groupKeyGenerator, _aggregationFunctions, _groupByResultHolders);
   }
 }
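The new `aggregate()` builds the `blockValSetMap` once per function through `AggregationFunctionUtils.getBlockValSetMap()`. It keys the map by `TransformExpressionTree` instead of by the expression's `toString()`. The sketch below shows why this works. It assumes, as the commit message implies, that the tree has value-based `hashCode()` and `equals()`. `Expr`, `BlockValSetMapSketch` and the `int[]` column values are hypothetical stand-ins, not Pinot classes.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

class BlockValSetMapSketch {
  // Hypothetical stand-in for AggregationFunctionUtils.getBlockValSetMap(): one entry per
  // input expression of the function, values taken from a simulated transform block.
  static Map<Expr, int[]> getBlockValSetMap(List<Expr> inputExpressions, Map<Expr, int[]> transformBlock) {
    Map<Expr, int[]> blockValSetMap = new HashMap<>();
    for (Expr inputExpression : inputExpressions) {
      blockValSetMap.put(inputExpression, transformBlock.get(inputExpression));
    }
    return blockValSetMap;
  }

  public static void main(String[] args) {
    Map<Expr, int[]> transformBlock = Map.of(Expr.column("col"), new int[]{1, 2, 3});
    Map<Expr, int[]> blockValSetMap = getBlockValSetMap(List.of(Expr.column("col")), transformBlock);
    // A structurally equal key created elsewhere finds the same entry; no toString() needed.
    System.out.println(blockValSetMap.get(Expr.column("col")).length); // 3
  }
}

// Hypothetical stand-in for TransformExpressionTree: an immutable node with structural
// equals()/hashCode(), so it can key a HashMap directly.
final class Expr {
  final String name;          // column name, or function name for a transform
  final List<Expr> children;  // empty for a plain column

  Expr(String name, List<Expr> children) {
    this.name = name;
    this.children = List.copyOf(children);
  }

  static Expr column(String name) {
    return new Expr(name, List.of());
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof Expr)) {
      return false;
    }
    Expr other = (Expr) o;
    return name.equals(other.name) && children.equals(other.children);
  }

  @Override
  public int hashCode() {
    return Objects.hash(name, children);
  }
}
```

Because equality is structural, each lookup hashes the tree directly. No string is rebuilt and compared per block.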
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java
index 75ef1a1..c1266c5 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java
@@ -18,7 +18,6 @@
  */
 package org.apache.pinot.core.query.aggregation.groupby;
 
-import javax.annotation.Nonnull;
 import org.apache.pinot.core.operator.blocks.TransformBlock;
 
 
@@ -32,7 +31,7 @@ public interface GroupByExecutor {
    *
    * @param transformBlock Transform block
    */
-  void process(@Nonnull TransformBlock transformBlock);
+  void process(TransformBlock transformBlock);
 
   /**
    * Returns the result of group-by aggregation.
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
index 3bde93a..825ebf2 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
@@ -20,19 +20,16 @@ package org.apache.pinot.core.query.reduce;
 
 import java.io.Serializable;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.metrics.BrokerMetrics;
-import org.apache.pinot.common.request.AggregationInfo;
 import org.apache.pinot.common.request.BrokerRequest;
 import org.apache.pinot.common.response.broker.AggregationResult;
 import org.apache.pinot.common.response.broker.BrokerResponseNative;
 import org.apache.pinot.common.response.broker.ResultTable;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.common.utils.DataTable;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 import org.apache.pinot.core.transport.ServerRoutingInstance;
@@ -44,16 +41,12 @@ import org.apache.pinot.core.util.QueryOptions;
  */
 public class AggregationDataTableReducer implements DataTableReducer {
   private final AggregationFunction[] _aggregationFunctions;
-  private final List<AggregationInfo> _aggregationInfos;
-  private final int _numAggregationFunctions;
   private final boolean _preserveType;
   private final boolean _responseFormatSql;
 
   AggregationDataTableReducer(BrokerRequest brokerRequest, AggregationFunction[] aggregationFunctions,
       QueryOptions queryOptions) {
     _aggregationFunctions = aggregationFunctions;
-    _aggregationInfos = brokerRequest.getAggregationsInfo();
-    _numAggregationFunctions = aggregationFunctions.length;
     _preserveType = queryOptions.isPreserveType();
     _responseFormatSql = queryOptions.isResponseFormatSQL();
   }
@@ -67,7 +60,6 @@ public class AggregationDataTableReducer implements DataTableReducer {
   public void reduceAndSetResults(String tableName, DataSchema dataSchema,
       Map<ServerRoutingInstance, DataTable> dataTableMap, BrokerResponseNative brokerResponseNative,
       BrokerMetrics brokerMetrics) {
-
     if (dataTableMap.isEmpty()) {
       if (_responseFormatSql) {
         DataSchema finalDataSchema = getResultTableDataSchema();
@@ -76,14 +68,11 @@ public class AggregationDataTableReducer implements DataTableReducer {
       return;
     }
 
-    assert dataSchema != null;
-
-    Collection<DataTable> dataTables = dataTableMap.values();
-
-    // Merge results from all data tables.
-    Object[] intermediateResults = new Object[_numAggregationFunctions];
-    for (DataTable dataTable : dataTables) {
-      for (int i = 0; i < _numAggregationFunctions; i++) {
+    // Merge results from all data tables
+    int numAggregationFunctions = _aggregationFunctions.length;
+    Object[] intermediateResults = new Object[numAggregationFunctions];
+    for (DataTable dataTable : dataTableMap.values()) {
+      for (int i = 0; i < numAggregationFunctions; i++) {
         Object intermediateResultToMerge;
         DataSchema.ColumnDataType columnDataType = dataSchema.getColumnDataType(i);
         switch (columnDataType) {
@@ -120,8 +109,9 @@ public class AggregationDataTableReducer implements DataTableReducer {
    */
   private ResultTable reduceToResultTable(Object[] intermediateResults) {
     List<Object[]> rows = new ArrayList<>(1);
-    Object[] row = new Object[_numAggregationFunctions];
-    for (int i = 0; i < _numAggregationFunctions; i++) {
+    int numAggregationFunctions = _aggregationFunctions.length;
+    Object[] row = new Object[numAggregationFunctions];
+    for (int i = 0; i < numAggregationFunctions; i++) {
       row[i] = _aggregationFunctions[i].extractFinalResult(intermediateResults[i]);
     }
     rows.add(row);
@@ -134,9 +124,10 @@ public class AggregationDataTableReducer implements DataTableReducer {
    * Sets aggregation results into AggregationResults
    */
   private List<AggregationResult> reduceToAggregationResult(Object[] intermediateResults, DataSchema dataSchema) {
-    // Extract final results and set them into the broker response.
-    List<AggregationResult> reducedAggregationResults = new ArrayList<>(_numAggregationFunctions);
-    for (int i = 0; i < _numAggregationFunctions; i++) {
+    // Extract final results and set them into the broker response
+    int numAggregationFunctions = _aggregationFunctions.length;
+    List<AggregationResult> reducedAggregationResults = new ArrayList<>(numAggregationFunctions);
+    for (int i = 0; i < numAggregationFunctions; i++) {
       Serializable resultValue = AggregationFunctionUtils
           .getSerializableValue(_aggregationFunctions[i].extractFinalResult(intermediateResults[i]));
 
@@ -153,13 +144,13 @@ public class AggregationDataTableReducer implements DataTableReducer {
    * Constructs the data schema for the final results table
    */
   private DataSchema getResultTableDataSchema() {
-    String[] finalColumnNames = new String[_numAggregationFunctions];
-    DataSchema.ColumnDataType[] finalColumnDataTypes = new DataSchema.ColumnDataType[_numAggregationFunctions];
-    for (int i = 0; i < _numAggregationFunctions; i++) {
-      AggregationFunctionContext aggregationFunctionContext =
-          AggregationFunctionUtils.getAggregationFunctionContext(_aggregationInfos.get(i));
-      finalColumnNames[i] = aggregationFunctionContext.getResultColumnName();
-      finalColumnDataTypes[i] = _aggregationFunctions[i].getFinalResultColumnType();
+    int numAggregationFunctions = _aggregationFunctions.length;
+    String[] finalColumnNames = new String[numAggregationFunctions];
+    DataSchema.ColumnDataType[] finalColumnDataTypes = new DataSchema.ColumnDataType[numAggregationFunctions];
+    for (int i = 0; i < numAggregationFunctions; i++) {
+      AggregationFunction aggregationFunction = _aggregationFunctions[i];
+      finalColumnNames[i] = aggregationFunction.getResultColumnName();
+      finalColumnDataTypes[i] = aggregationFunction.getFinalResultColumnType();
     }
     return new DataSchema(finalColumnNames, finalColumnDataTypes);
   }
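After the cleanup, the reducer takes the function count from `_aggregationFunctions.length`. It merges each function's intermediate results across all data tables, then extracts the final results. Below is a minimal sketch of that merge-then-finalize shape. `Agg`, `SumAgg`, `AvgAgg` and the `Object[]` rows that stand in for `DataTable`s are hypothetical. The middle of the real merge loop is not shown in this hunk, so the null check for the first table is an assumption.

```java
import java.util.List;

class ReduceSketch {
  // Each Object[] stands in for one server's DataTable: one intermediate result per function.
  static Object[] reduce(Agg[] functions, List<Object[]> dataTables) {
    int numAggregationFunctions = functions.length;
    Object[] intermediateResults = new Object[numAggregationFunctions];
    for (Object[] dataTable : dataTables) {
      for (int i = 0; i < numAggregationFunctions; i++) {
        Object toMerge = dataTable[i];
        // The first table seeds the result; later tables are merged into it.
        intermediateResults[i] =
            intermediateResults[i] == null ? toMerge : functions[i].merge(intermediateResults[i], toMerge);
      }
    }
    Object[] row = new Object[numAggregationFunctions];
    for (int i = 0; i < numAggregationFunctions; i++) {
      row[i] = functions[i].extractFinalResult(intermediateResults[i]);
    }
    return row;
  }

  public static void main(String[] args) {
    Agg[] functions = {new SumAgg(), new AvgAgg()};
    List<Object[]> dataTables = List.of(
        new Object[]{10.0, new double[]{10.0, 2.0}},
        new Object[]{5.0, new double[]{20.0, 3.0}});
    Object[] row = reduce(functions, dataTables);
    System.out.println(row[0] + " " + row[1]); // 15.0 6.0
  }
}

// Hypothetical, simplified stand-in for AggregationFunction: only what the reduce needs.
interface Agg {
  Object merge(Object a, Object b);

  Object extractFinalResult(Object intermediate);
}

class SumAgg implements Agg {
  public Object merge(Object a, Object b) {
    return (Double) a + (Double) b;
  }

  public Object extractFinalResult(Object intermediate) {
    return intermediate;
  }
}

// An AVG-style function keeps a (sum, count) pair and divides only when finalizing.
class AvgAgg implements Agg {
  public Object merge(Object a, Object b) {
    double[] x = (double[]) a, y = (double[]) b;
    return new double[]{x[0] + y[0], x[1] + y[1]};
  }

  public Object extractFinalResult(Object intermediate) {
    double[] pair = (double[]) intermediate;
    return pair[0] / pair[1];
  }
}
```

In the AVG case, the final value is computed only once, after all tables are merged.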
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/CombineService.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/CombineService.java
index 79682e9..0e265e8 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/CombineService.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/CombineService.java
@@ -27,7 +27,7 @@ import org.apache.pinot.common.request.Selection;
 import org.apache.pinot.common.response.ProcessingException;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.selection.SelectionOperatorUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -68,19 +68,19 @@ public class CombineService {
           return;
         }
 
-        AggregationFunctionContext[] mergedAggregationFunctionContexts = mergedBlock.getAggregationFunctionContexts();
-        if (mergedAggregationFunctionContexts == null) {
+        AggregationFunction[] mergedAggregationFunctions = mergedBlock.getAggregationFunctions();
+        if (mergedAggregationFunctions == null) {
           // No data in merged block.
-          mergedBlock.setAggregationFunctionContexts(blockToMerge.getAggregationFunctionContexts());
+          mergedBlock.setAggregationFunctions(blockToMerge.getAggregationFunctions());
           mergedBlock.setAggregationResults(aggregationResultToMerge);
-        }
-
-        // Merge two block.
-        List<Object> mergedAggregationResult = mergedBlock.getAggregationResult();
-        int numAggregationFunctions = mergedAggregationFunctionContexts.length;
-        for (int i = 0; i < numAggregationFunctions; i++) {
-          mergedAggregationResult.set(i, mergedAggregationFunctionContexts[i].getAggregationFunction()
-              .merge(mergedAggregationResult.get(i), aggregationResultToMerge.get(i)));
+        } else {
+          // Merge two blocks.
+          List<Object> mergedAggregationResult = mergedBlock.getAggregationResult();
+          int numAggregationFunctions = mergedAggregationFunctions.length;
+          for (int i = 0; i < numAggregationFunctions; i++) {
+            mergedAggregationResult.set(i,
+                mergedAggregationFunctions[i].merge(mergedAggregationResult.get(i), aggregationResultToMerge.get(i)));
+          }
         }
       } else {
         // Combine aggregation group-by result, which should not come into CombineService.
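In the previous `CombineService` code, the merged block could still have no aggregation functions. In that case the code copied the functions from `blockToMerge`, but then fell through to the merge loop. That loop read `mergedAggregationFunctionContexts.length` from a local variable that was still null. The new `else` branch keeps the two cases apart. The sketch below mirrors the fixed control flow. `ResultsBlock`, `CombineSketch` and the use of `BinaryOperator<Long>` as the merge function are hypothetical simplifications.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BinaryOperator;

class CombineSketch {
  static void mergeInto(ResultsBlock merged, ResultsBlock toMerge) {
    if (merged.functions == null) {
      // No data in the merged block yet: adopt the other block's functions and results.
      merged.functions = toMerge.functions;
      merged.results = new ArrayList<>(toMerge.results);
    } else {
      // Merge only when the block already held data; merging right after adopting
      // would combine the incoming results with themselves.
      for (int i = 0; i < merged.functions.size(); i++) {
        merged.results.set(i, merged.functions.get(i).apply(merged.results.get(i), toMerge.results.get(i)));
      }
    }
  }

  public static void main(String[] args) {
    List<BinaryOperator<Long>> functions = List.of((a, b) -> a + b, (a, b) -> Math.max(a, b));
    ResultsBlock merged = new ResultsBlock(null, null);
    mergeInto(merged, new ResultsBlock(functions, List.of(3L, 7L)));
    mergeInto(merged, new ResultsBlock(functions, List.of(4L, 9L)));
    System.out.println(merged.results); // [7, 9]
  }
}

// Hypothetical stand-in for IntermediateResultsBlock: merge functions plus one result each.
class ResultsBlock {
  List<BinaryOperator<Long>> functions; // null until the block holds data
  List<Long> results;

  ResultsBlock(List<BinaryOperator<Long>> functions, List<Long> results) {
    this.functions = functions;
    this.results = results == null ? null : new ArrayList<Long>(results);
  }
}
```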
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ComparisonFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ComparisonFunction.java
index ec96f0e..9841c10 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ComparisonFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ComparisonFunction.java
@@ -19,7 +19,7 @@
 package org.apache.pinot.core.query.reduce;
 
 import org.apache.pinot.common.request.AggregationInfo;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory;
 
 
 //This class will be inherited by different classes that compare (e.g., for equality) the input value by the base value
@@ -27,8 +27,7 @@ public abstract class ComparisonFunction {
   private final String _functionExpression;
 
   protected ComparisonFunction(AggregationInfo aggregationInfo) {
-    _functionExpression =
-        AggregationFunctionUtils.getAggregationFunctionContext(aggregationInfo).getAggregationColumnName();
+    _functionExpression = AggregationFunctionFactory.getAggregationFunction(aggregationInfo, null).getColumnName();
   }
 
   public abstract boolean isComparisonValid(String aggResult);
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/DistinctDataTableReducer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/DistinctDataTableReducer.java
index de5db1c..9c5f222 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/DistinctDataTableReducer.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/DistinctDataTableReducer.java
@@ -76,8 +76,7 @@ public class DistinctDataTableReducer implements DataTableReducer {
         DataSchema finalDataSchema = getEmptyResultTableDataSchema();
         brokerResponseNative.setResultTable(new ResultTable(finalDataSchema, Collections.emptyList()));
       } else {
-        brokerResponseNative
-            .setSelectionResults(new SelectionResults(getDistinctColumns(), Collections.emptyList()));
+        brokerResponseNative.setSelectionResults(new SelectionResults(getDistinctColumns(), Collections.emptyList()));
       }
       return;
     }
@@ -163,13 +162,15 @@ public class DistinctDataTableReducer implements DataTableReducer {
   }
 
   private List<String> getDistinctColumns() {
-    return AggregationFunctionUtils.getAggregationExpressions(_brokerRequest.getAggregationsInfo().get(0));
+    return AggregationFunctionUtils.getArguments(_brokerRequest.getAggregationsInfo().get(0));
   }
 
   private DataSchema getEmptyResultTableDataSchema() {
-    String[] columns = getDistinctColumns().toArray(new String[0]);
-    DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[columns.length];
+    List<String> distinctColumns = getDistinctColumns();
+    int numColumns = distinctColumns.size();
+    String[] columnNames = distinctColumns.toArray(new String[numColumns]);
+    DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[numColumns];
     Arrays.fill(columnDataTypes, DataSchema.ColumnDataType.STRING);
-    return new DataSchema(columns, columnDataTypes);
+    return new DataSchema(columnNames, columnDataTypes);
   }
 }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
index 97c2019..98e6a01 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
@@ -31,7 +31,6 @@ import java.util.function.BiFunction;
 import org.apache.calcite.sql.SqlKind;
 import org.apache.pinot.common.metrics.BrokerMeter;
 import org.apache.pinot.common.metrics.BrokerMetrics;
-import org.apache.pinot.common.request.AggregationInfo;
 import org.apache.pinot.common.request.BrokerRequest;
 import org.apache.pinot.common.request.Expression;
 import org.apache.pinot.common.request.GroupBy;
@@ -48,7 +47,6 @@ import org.apache.pinot.common.utils.request.RequestUtils;
 import org.apache.pinot.core.data.table.ConcurrentIndexedTable;
 import org.apache.pinot.core.data.table.IndexedTable;
 import org.apache.pinot.core.data.table.Record;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByTrimmingService;
@@ -63,8 +61,6 @@ import org.apache.pinot.core.util.QueryOptions;
 public class GroupByDataTableReducer implements DataTableReducer {
   private final BrokerRequest _brokerRequest;
   private final AggregationFunction[] _aggregationFunctions;
-  private final List<AggregationInfo> _aggregationInfos;
-  private final AggregationFunctionContext[] _aggregationFunctionContexts;
   private final List<SelectionSort> _orderBy;
   private final GroupBy _groupBy;
   private final int _numAggregationFunctions;
@@ -80,8 +76,6 @@ public class GroupByDataTableReducer implements DataTableReducer {
       QueryOptions queryOptions) {
     _brokerRequest = brokerRequest;
     _aggregationFunctions = aggregationFunctions;
-    _aggregationInfos = brokerRequest.getAggregationsInfo();
-    _aggregationFunctionContexts = AggregationFunctionUtils.getAggregationFunctionContexts(_brokerRequest);
     _numAggregationFunctions = aggregationFunctions.length;
     _groupBy = brokerRequest.getGroupBy();
     _numGroupBy = _groupBy.getExpressionsSize();
@@ -154,7 +148,7 @@ public class GroupByDataTableReducer implements DataTableReducer {
       // This is the primary PQL compliant group by
 
       boolean[] aggregationFunctionSelectStatus =
-          AggregationFunctionUtils.getAggregationFunctionsSelectStatus(_aggregationInfos);
+          AggregationFunctionUtils.getAggregationFunctionsSelectStatus(_brokerRequest.getAggregationsInfo());
       setGroupByHavingResults(brokerResponseNative, aggregationFunctionSelectStatus, dataTables,
           _brokerRequest.getHavingFilterQuery(), _brokerRequest.getHavingFilterSubQueryMap());
 
@@ -303,10 +297,9 @@ public class GroupByDataTableReducer implements DataTableReducer {
   }
 
   private IndexedTable getIndexedTable(DataSchema dataSchema, Collection<DataTable> dataTables) {
-
     int indexedTableCapacity = GroupByUtils.getTableCapacity(_groupBy, _orderBy);
     IndexedTable indexedTable =
-        new ConcurrentIndexedTable(dataSchema, _aggregationInfos, _orderBy, indexedTableCapacity);
+        new ConcurrentIndexedTable(dataSchema, _aggregationFunctions, _orderBy, indexedTableCapacity);
 
     for (DataTable dataTable : dataTables) {
       BiFunction[] functions = new BiFunction[_numColumns];
@@ -552,8 +545,9 @@ public class GroupByDataTableReducer implements DataTableReducer {
         if (aggregationFunctionsSelectStatus[i]) {
           finalColumnNames[count] = columnNames[i];
           finalOutResultMaps[count] = finalResultMaps[i];
-          finalResultTableAggNames[count] = _aggregationFunctionContexts[i].getResultColumnName();
-          finalAggregationFunctions[count] = _aggregationFunctions[i];
+          AggregationFunction aggregationFunction = _aggregationFunctions[i];
+          finalResultTableAggNames[count] = aggregationFunction.getResultColumnName();
+          finalAggregationFunctions[count] = aggregationFunction;
           count++;
         }
       }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/request/ServerQueryRequest.java b/pinot-core/src/main/java/org/apache/pinot/core/query/request/ServerQueryRequest.java
index 3865f2e..bf2fa07 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/request/ServerQueryRequest.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/request/ServerQueryRequest.java
@@ -95,7 +95,7 @@ public class ServerQueryRequest {
       _aggregationExpressions = new HashSet<>();
       for (AggregationInfo aggregationInfo : aggregationsInfo) {
         if (!aggregationInfo.getAggregationType().equalsIgnoreCase(AggregationFunctionType.COUNT.getName())) {
-          for (String expressions : AggregationFunctionUtils.getAggregationExpressions(aggregationInfo)) {
+          for (String expressions : AggregationFunctionUtils.getArguments(aggregationInfo)) {
             _aggregationExpressions.add(TransformExpressionTree.compileToExpressionTree(expressions));
           }
         }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/startree/StarTreeUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/startree/StarTreeUtils.java
index f8e9bd1..3810826 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/startree/StarTreeUtils.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/startree/StarTreeUtils.java
@@ -18,7 +18,6 @@
  */
 package org.apache.pinot.core.startree;
 
-import java.util.Arrays;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
@@ -27,8 +26,6 @@ import org.apache.pinot.common.request.BrokerRequest;
 import org.apache.pinot.common.request.FilterOperator;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.common.utils.request.FilterQueryTree;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.startree.v2.AggregationFunctionColumnPair;
 import org.apache.pinot.core.startree.v2.StarTreeV2Metadata;
 
@@ -57,8 +54,8 @@ public class StarTreeUtils {
    * </ul>
    */
   public static boolean isFitForStarTree(StarTreeV2Metadata starTreeV2Metadata,
-      Set<AggregationFunctionColumnPair> aggregationFunctionColumnPairs,
-      @Nullable Set<TransformExpressionTree> groupByExpressions, @Nullable FilterQueryTree rootFilterNode) {
+      AggregationFunctionColumnPair[] aggregationFunctionColumnPairs,
+      @Nullable TransformExpressionTree[] groupByExpressions, @Nullable FilterQueryTree rootFilterNode) {
     // Check aggregations
     for (AggregationFunctionColumnPair aggregationFunctionColumnPair : aggregationFunctionColumnPairs) {
       if (!starTreeV2Metadata.containsFunctionColumnPair(aggregationFunctionColumnPair)) {
@@ -102,29 +99,4 @@ public class StarTreeUtils {
     String column = filterNode.getColumn();
     return starTreeDimensions.contains(column);
   }
-
-  /**
-   * Creates a {@link AggregationFunctionContext} from the given context but replace the column with the function-column
-   * pair.
-   */
-  public static AggregationFunctionContext createStarTreeFunctionContext(AggregationFunctionContext functionContext) {
-    AggregationFunction function = functionContext.getAggregationFunction();
-    AggregationFunctionColumnPair functionColumnPair =
-        new AggregationFunctionColumnPair(function.getType(), functionContext.getColumnName());
-    return new AggregationFunctionContext(function, Arrays.asList(functionColumnPair.toColumnName()));
-  }
-
-  /**
-   * Creates an array of {@link AggregationFunctionContext}s from the given contexts but replace the column with the
-   * function-column pair.
-   */
-  public static AggregationFunctionContext[] createStarTreeFunctionContexts(
-      AggregationFunctionContext[] functionContexts) {
-    int numContexts = functionContexts.length;
-    AggregationFunctionContext[] starTreeFunctionContexts = new AggregationFunctionContext[numContexts];
-    for (int i = 0; i < numContexts; i++) {
-      starTreeFunctionContexts[i] = createStarTreeFunctionContext(functionContexts[i]);
-    }
-    return starTreeFunctionContexts;
-  }
 }
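`isFitForStarTree()` now takes arrays instead of sets. The visible part of the method has two requirements:

- every function-column pair must be present in the star-tree metadata;
- every filter column must be a star-tree dimension.

The sketch below implements a simplified version of such a check over plain strings. `StarTreeFitSketch` is hypothetical. The group-by rule (each group-by column must also be a dimension) is an assumption, because that part of the method is not in this hunk.

```java
import java.util.List;
import java.util.Set;

class StarTreeFitSketch {
  // Hypothetical, simplified fit check over string identifiers.
  static boolean isFitForStarTree(Set<String> storedFunctionColumnPairs, Set<String> dimensions,
      String[] functionColumnPairs, String[] groupByColumns, List<String> filterColumns) {
    // Every requested function-column pair must have been pre-aggregated.
    for (String pair : functionColumnPairs) {
      if (!storedFunctionColumnPairs.contains(pair)) {
        return false;
      }
    }
    // Assumed: group-by columns must be star-tree dimensions (null means no group-by).
    if (groupByColumns != null) {
      for (String column : groupByColumns) {
        if (!dimensions.contains(column)) {
          return false;
        }
      }
    }
    // Every filtered column must be a star-tree dimension.
    for (String column : filterColumns) {
      if (!dimensions.contains(column)) {
        return false;
      }
    }
    return true;
  }

  public static void main(String[] args) {
    Set<String> pairs = Set.of("sum__revenue", "count__*");
    Set<String> dimensions = Set.of("country", "browser");
    System.out.println(isFitForStarTree(pairs, dimensions,
        new String[]{"sum__revenue"}, new String[]{"country"}, List.of("browser"))); // true
    System.out.println(isFitForStarTree(pairs, dimensions,
        new String[]{"max__revenue"}, null, List.of())); // false
  }
}
```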
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeAggregationExecutor.java b/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeAggregationExecutor.java
index daf9b64..110efb2 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeAggregationExecutor.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeAggregationExecutor.java
@@ -18,18 +18,10 @@
  */
 package org.apache.pinot.core.startree.executor;
 
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import org.apache.pinot.common.function.AggregationFunctionType;
-import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.operator.blocks.TransformBlock;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
-import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.DefaultAggregationExecutor;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
-import org.apache.pinot.core.startree.StarTreeUtils;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 import org.apache.pinot.core.startree.v2.AggregationFunctionColumnPair;
 
 
@@ -37,50 +29,31 @@ import org.apache.pinot.core.startree.v2.AggregationFunctionColumnPair;
  * The <code>StarTreeAggregationExecutor</code> class is the aggregation executor for star-tree index.
  * <ul>
  *   <li>The column in function context is function-column pair</li>
- *   <li>No UDF in aggregation</li>
+ *   <li>No transform function in aggregation</li>
  *   <li>For <code>COUNT</code> aggregation function, we need to aggregate on the pre-aggregated column</li>
  * </ul>
  */
 public class StarTreeAggregationExecutor extends DefaultAggregationExecutor {
-  // StarTree converts column names from min(col) to min__col, this is to store the original mapping.
-  private final String[][] _functionArgs;
+  private final AggregationFunctionColumnPair[] _aggregationFunctionColumnPairs;
 
-  public StarTreeAggregationExecutor(AggregationFunctionContext[] functionContexts) {
-    super(StarTreeUtils.createStarTreeFunctionContexts(functionContexts));
+  public StarTreeAggregationExecutor(AggregationFunction[] aggregationFunctions) {
+    super(aggregationFunctions);
 
-    _functionArgs = new String[functionContexts.length][];
-    for (int i = 0; i < functionContexts.length; i++) {
-      List<String> expressions = functionContexts[i].getExpressions();
-      _functionArgs[i] = new String[expressions.size()];
-
-      for (int j = 0; j < expressions.size(); j++) {
-        _functionArgs[i][j] = expressions.get(j);
-      }
+    int numAggregationFunctions = aggregationFunctions.length;
+    _aggregationFunctionColumnPairs = new AggregationFunctionColumnPair[numAggregationFunctions];
+    for (int i = 0; i < numAggregationFunctions; i++) {
+      _aggregationFunctionColumnPairs[i] =
+          AggregationFunctionUtils.getAggregationFunctionColumnPair(aggregationFunctions[i]);
     }
   }
 
   @Override
   public void aggregate(TransformBlock transformBlock) {
+    int numAggregationFunctions = _aggregationFunctions.length;
     int length = transformBlock.getNumDocs();
-    for (int i = 0; i < _numFunctions; i++) {
-      AggregationFunction function = _functions[i];
-      AggregationResultHolder resultHolder = _resultHolders[i];
-
-      AggregationFunctionType functionType = function.getType();
-      if ((functionType == AggregationFunctionType.COUNT)) {
-        BlockValSet blockValueSet =
-            transformBlock.getBlockValueSet(AggregationFunctionColumnPair.COUNT_STAR_COLUMN_NAME);
-        function.aggregate(length, resultHolder, Collections.singletonMap(_functionArgs[i][0], blockValueSet));
-      } else {
-
-        Map<String, BlockValSet> blockValSetMap = new HashMap<>();
-        for (int j = 0; j < _functionArgs[i].length; j++) {
-          blockValSetMap.put(_functionArgs[i][j], transformBlock.getBlockValueSet(
-              AggregationFunctionColumnPair.toColumnName(functionType, _expressions[i][j].getValue())));
-        }
-
-        function.aggregate(length, resultHolder, blockValSetMap);
-      }
+    for (int i = 0; i < numAggregationFunctions; i++) {
+      _aggregationFunctions[i].aggregate(length, _aggregationResultHolders[i],
+          AggregationFunctionUtils.getBlockValSetMap(_aggregationFunctionColumnPairs[i], transformBlock));
     }
   }
 }
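The star-tree executors now compute one `AggregationFunctionColumnPair` per function in the constructor. At aggregation time they read values from the pre-aggregated column. The map handed to the function is still keyed by the function's own input expression. The sketch below makes two assumptions:

- the naming scheme follows the removed comment (`min(col)` becomes `min__col`);
- `COUNT(*)` reads a column named `count__*`.

This matches the commit message's note that COUNT on star-tree has a single `blockValSetMap` entry for column `"*"`. `StarTreeColumnSketch` is hypothetical and uses `String` keys where Pinot uses `TransformExpressionTree`.

```java
import java.util.Collections;
import java.util.Map;

class StarTreeColumnSketch {
  // Assumed naming scheme, following the removed comment "min(col) to min__col".
  static final String DELIMITER = "__";
  static final String COUNT_STAR_COLUMN_NAME = "count" + DELIMITER + "*";

  static String toColumnName(String functionName, String column) {
    return functionName + DELIMITER + column;
  }

  // Hypothetical counterpart of getBlockValSetMap(functionColumnPair, transformBlock): one entry
  // keyed by the function's own input column, with values read from the pre-aggregated
  // star-tree column. A missing star-tree column yields a null value.
  static Map<String, long[]> getBlockValSetMap(String functionName, String column,
      Map<String, long[]> starTreeColumns) {
    return Collections.singletonMap(column, starTreeColumns.get(toColumnName(functionName, column)));
  }

  public static void main(String[] args) {
    Map<String, long[]> starTreeColumns =
        Map.of("min__col", new long[]{1, 4}, COUNT_STAR_COLUMN_NAME, new long[]{2, 3});
    // COUNT(*) has the single input column "*", which is read from count__*.
    System.out.println(getBlockValSetMap("count", "*", starTreeColumns).get("*")[1]); // 3
    System.out.println(getBlockValSetMap("min", "col", starTreeColumns).get("col")[0]); // 1
  }
}
```

Precomputing the pair moves the string building out of the per-block path. That path now only performs the map lookup.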
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeGroupByExecutor.java b/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeGroupByExecutor.java
index cd0518a..fef887c 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeGroupByExecutor.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeGroupByExecutor.java
@@ -18,22 +18,15 @@
  */
 package org.apache.pinot.core.startree.executor;
 
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
-import javax.annotation.Nonnull;
-import org.apache.pinot.common.function.AggregationFunctionType;
-import org.apache.pinot.common.request.GroupBy;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.operator.blocks.TransformBlock;
 import org.apache.pinot.core.operator.transform.TransformOperator;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
-import org.apache.pinot.core.startree.StarTreeUtils;
 import org.apache.pinot.core.startree.v2.AggregationFunctionColumnPair;
 
 
@@ -41,59 +34,36 @@ import org.apache.pinot.core.startree.v2.AggregationFunctionColumnPair;
  * The <code>StarTreeGroupByExecutor</code> class is the group-by executor for star-tree index.
  * <ul>
  *   <li>The column in function context is function-column pair</li>
- *   <li>No UDF in aggregation</li>
+ *   <li>No transform function in aggregation</li>
  *   <li>For <code>COUNT</code> aggregation function, we need to aggregate on the pre-aggregated column</li>
  * </ul>
  */
 public class StarTreeGroupByExecutor extends DefaultGroupByExecutor {
+  private final AggregationFunctionColumnPair[] _aggregationFunctionColumnPairs;
 
-  // StarTree converts column names from min(col) to min__col, this is to store the original mapping.
-  private final String[][] _functionArgs;
+  public StarTreeGroupByExecutor(AggregationFunction[] aggregationFunctions,
+      TransformExpressionTree[] groupByExpressions, int maxInitialResultHolderCapacity, int numGroupsLimit,
+      TransformOperator transformOperator) {
+    super(aggregationFunctions, groupByExpressions, maxInitialResultHolderCapacity, numGroupsLimit, transformOperator);
 
-  public StarTreeGroupByExecutor(@Nonnull AggregationFunctionContext[] functionContexts, @Nonnull GroupBy groupBy,
-      int maxInitialResultHolderCapacity, int numGroupsLimit, @Nonnull TransformOperator transformOperator) {
-    super(StarTreeUtils.createStarTreeFunctionContexts(functionContexts), groupBy, maxInitialResultHolderCapacity,
-        numGroupsLimit, transformOperator);
-
-    _functionArgs = new String[functionContexts.length][];
-    for (int i = 0; i < functionContexts.length; i++) {
-      List<String> expressions = functionContexts[i].getExpressions();
-      _functionArgs[i] = new String[expressions.size()];
-
-      for (int j = 0; j < expressions.size(); j++) {
-        _functionArgs[i][j] = expressions.get(j);
-      }
+    int numAggregationFunctions = aggregationFunctions.length;
+    _aggregationFunctionColumnPairs = new AggregationFunctionColumnPair[numAggregationFunctions];
+    for (int i = 0; i < numAggregationFunctions; i++) {
+      _aggregationFunctionColumnPairs[i] =
+          AggregationFunctionUtils.getAggregationFunctionColumnPair(aggregationFunctions[i]);
     }
   }
 
   @Override
-  protected void aggregate(@Nonnull TransformBlock transformBlock, int length, int functionIndex) {
-    AggregationFunction function = _functions[functionIndex];
-    GroupByResultHolder resultHolder = _resultHolders[functionIndex];
-
-    BlockValSet blockValueSet;
-    Map<String, BlockValSet> blockValSetMap;
-
-    AggregationFunctionType functionType = function.getType();
-    if (functionType == AggregationFunctionType.COUNT) {
-      blockValueSet = transformBlock.getBlockValueSet(AggregationFunctionColumnPair.COUNT_STAR_COLUMN_NAME);
-      function.getInputExpressions();
-      blockValSetMap = Collections.singletonMap(_functionArgs[functionIndex][0], blockValueSet);
-    } else {
-
-      blockValSetMap = new HashMap<>();
-      for (int i = 0; i < _functionArgs[functionIndex].length; i++) {
-        TransformExpressionTree aggregationExpression = _aggregationExpressions[functionIndex][i];
-        blockValueSet = transformBlock.getBlockValueSet(
-            AggregationFunctionColumnPair.toColumnName(functionType, aggregationExpression.getValue()));
-        blockValSetMap.put(_functionArgs[functionIndex][i], blockValueSet);
-      }
-    }
-
+  protected void aggregate(TransformBlock transformBlock, int length, int functionIndex) {
+    AggregationFunction aggregationFunction = _aggregationFunctions[functionIndex];
+    GroupByResultHolder groupByResultHolder = _groupByResultHolders[functionIndex];
+    Map<TransformExpressionTree, BlockValSet> blockValSetMap =
+        AggregationFunctionUtils.getBlockValSetMap(_aggregationFunctionColumnPairs[functionIndex], transformBlock);
     if (_hasMVGroupByExpression) {
-      function.aggregateGroupByMV(length, _mvGroupKeys, resultHolder, blockValSetMap);
+      aggregationFunction.aggregateGroupByMV(length, _mvGroupKeys, groupByResultHolder, blockValSetMap);
     } else {
-      function.aggregateGroupBySV(length, _svGroupKeys, resultHolder, blockValSetMap);
+      aggregationFunction.aggregateGroupBySV(length, _svGroupKeys, groupByResultHolder, blockValSetMap);
     }
   }
-}
\ No newline at end of file
+}
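The rewritten aggregate() above looks up each function's inputs in a Map<TransformExpressionTree, BlockValSet>. The commit message says this avoids redundant string conversion and gives cheaper hashCode() and equals(). The minimal sketch below illustrates that idea under stated assumptions. ExprKey is a hypothetical stand-in for TransformExpressionTree and double[] stands in for BlockValSet. It is not Pinot's implementation. Its point is that a structurally equal key, built independently, finds the same map entry without first being rendered to a String.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class ExpressionKeyedMapSketch {

  /** Hypothetical stand-in for TransformExpressionTree: an immutable node with value-based equality. */
  static final class ExprKey {
    final String value;            // column name, or function name for a function node
    final List<ExprKey> children;  // empty for a plain column
    private final int hash;        // computed once so repeated HashMap lookups stay cheap

    private ExprKey(String value, List<ExprKey> children) {
      this.value = value;
      this.children = children;
      this.hash = Objects.hash(value, children);
    }

    static ExprKey column(String name) {
      return new ExprKey(name, Collections.<ExprKey>emptyList());
    }

    static ExprKey function(String name, ExprKey... args) {
      return new ExprKey(name, Collections.unmodifiableList(Arrays.asList(args)));
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) {
        return true;
      }
      if (!(o instanceof ExprKey)) {
        return false;
      }
      ExprKey that = (ExprKey) o;
      // Cheap cached-hash check first, then the structural comparison.
      return hash == that.hash && value.equals(that.value) && children.equals(that.children);
    }

    @Override
    public int hashCode() {
      return hash;
    }
  }

  /** Sums the first `length` values of the block that belongs to the given input expression. */
  static double sum(ExprKey input, Map<ExprKey, double[]> blockValSetMap, int length) {
    double[] values = blockValSetMap.get(input);  // looked up by the expression itself, not a string
    double total = 0;
    for (int i = 0; i < length; i++) {
      total += values[i];
    }
    return total;
  }

  public static void main(String[] args) {
    ExprKey m1 = ExprKey.column("m1");
    ExprKey plus = ExprKey.function("plus", ExprKey.column("m1"), ExprKey.column("m2"));

    Map<ExprKey, double[]> blockValSetMap = new HashMap<>();
    blockValSetMap.put(m1, new double[]{1, 2, 3});
    blockValSetMap.put(plus, new double[]{10, 20, 30});

    // A structurally equal key built independently finds the same entry.
    ExprKey samePlus = ExprKey.function("plus", ExprKey.column("m1"), ExprKey.column("m2"));
    System.out.println(sum(m1, blockValSetMap, 3));        // 6.0
    System.out.println(sum(samePlus, blockValSetMap, 2));  // 30.0
  }
}
```

Caching the hash in the constructor is a design choice of this sketch. It is safe only because the node is immutable.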
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/startree/plan/StarTreeTransformPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/startree/plan/StarTreeTransformPlanNode.java
index 10c2276..03938d1 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/startree/plan/StarTreeTransformPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/startree/plan/StarTreeTransformPlanNode.java
@@ -18,6 +18,7 @@
  */
 package org.apache.pinot.core.startree.plan;
 
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.Map;
@@ -40,8 +41,8 @@ public class StarTreeTransformPlanNode implements PlanNode {
   private final StarTreeProjectionPlanNode _starTreeProjectionPlanNode;
 
   public StarTreeTransformPlanNode(StarTreeV2 starTreeV2,
-      Set<AggregationFunctionColumnPair> aggregationFunctionColumnPairs,
-      @Nullable Set<TransformExpressionTree> groupByExpressions, @Nullable FilterQueryTree rootFilterNode,
+      AggregationFunctionColumnPair[] aggregationFunctionColumnPairs,
+      @Nullable TransformExpressionTree[] groupByExpressions, @Nullable FilterQueryTree rootFilterNode,
       @Nullable Map<String, String> debugOptions) {
     Set<String> projectionColumns = new HashSet<>();
     for (AggregationFunctionColumnPair aggregationFunctionColumnPair : aggregationFunctionColumnPairs) {
@@ -49,7 +50,7 @@ public class StarTreeTransformPlanNode implements PlanNode {
     }
     Set<String> groupByColumns;
     if (groupByExpressions != null) {
-      _groupByExpressions = groupByExpressions;
+      _groupByExpressions = new HashSet<>(Arrays.asList(groupByExpressions));
       groupByColumns = new HashSet<>();
       for (TransformExpressionTree groupByExpression : groupByExpressions) {
         groupByExpression.getColumns(groupByColumns);
@@ -65,6 +66,9 @@ public class StarTreeTransformPlanNode implements PlanNode {
 
   @Override
   public TransformOperator run() {
+    // NOTE: Here we do not put aggregation expressions into TransformOperator based on the following assumptions:
+    //       - They are all columns (not functions or constants), where no transform is required
+    //       - We never call TransformOperator.getResultMetadata() or TransformOperator.getDictionary() on them
     return new TransformOperator(_starTreeProjectionPlanNode.run(), _groupByExpressions);
   }
 
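StarTreeTransformPlanNode now receives the group-by expressions as an array, copies them into a HashSet, and collects every column they reference into the projection set. Per the NOTE added to run(), it relies on the aggregation inputs being plain columns. The sketch below approximates that column-collection step. Expr, collectColumns() and projectionColumns() are hypothetical names. Every leaf is treated as a column, a simplification that ignores literals. The exact column names Pinot derives from an AggregationFunctionColumnPair are not shown in this hunk, so plain names are used instead.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

public class StarTreeColumnsSketch {

  /** Hypothetical expression node; a leaf is treated as a column (simplification: no literals). */
  static final class Expr {
    final String name;
    final List<Expr> args;

    Expr(String name, Expr... args) {
      this.name = name;
      this.args = Arrays.asList(args);
    }

    /** Adds every column referenced anywhere in this expression to {@code out}. */
    void collectColumns(Set<String> out) {
      if (args.isEmpty()) {
        out.add(name);
      } else {
        for (Expr arg : args) {
          arg.collectColumns(out);
        }
      }
    }
  }

  /** Aggregation input columns (assumed plain columns) plus all columns used by group-by expressions. */
  static Set<String> projectionColumns(String[] aggregationColumns, Expr[] groupByExpressions) {
    Set<String> columns = new HashSet<>(Arrays.asList(aggregationColumns));
    if (groupByExpressions != null) {  // group-by is @Nullable in the constructor above
      for (Expr expr : groupByExpressions) {
        expr.collectColumns(columns);
      }
    }
    return columns;
  }

  public static void main(String[] args) {
    Expr[] groupBy = {new Expr("d1"), new Expr("div", new Expr("d2"), new Expr("d1"))};
    // TreeSet is used only to print in a stable order.
    System.out.println(new TreeSet<>(projectionColumns(new String[]{"m1", "m2"}, groupBy)));
    // [d1, d2, m1, m2]
  }
}
```

Because the result is a set, d1 is projected once even though two group-by expressions reference it.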
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/table/IndexedTableTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/table/IndexedTableTest.java
index 6c90739..960791a 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/data/table/IndexedTableTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/data/table/IndexedTableTest.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.data.table;
 
 import com.google.common.collect.Lists;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.Callable;
@@ -29,10 +30,12 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
-import org.apache.pinot.common.request.AggregationInfo;
 import org.apache.pinot.common.request.SelectionSort;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.MaxAggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.SumAggregationFunction;
 import org.testng.Assert;
 import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
@@ -44,31 +47,18 @@ import org.testng.annotations.Test;
 public class IndexedTableTest {
 
   @Test
-  public void testConcurrentIndexedTable() throws InterruptedException, TimeoutException, ExecutionException {
-
+  public void testConcurrentIndexedTable()
+      throws InterruptedException, TimeoutException, ExecutionException {
     DataSchema dataSchema = new DataSchema(new String[]{"d1", "d2", "d3", "sum(m1)", "max(m2)"},
-        new ColumnDataType[]{ColumnDataType.STRING, ColumnDataType.INT, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE,
-            ColumnDataType.DOUBLE});
-
-    AggregationInfo agg1 = new AggregationInfo();
-    List<String> args1 = new ArrayList<>(1);
-    args1.add("m1");
-    agg1.setExpressions(args1);
-    agg1.setAggregationType("sum");
-
-    AggregationInfo agg2 = new AggregationInfo();
-    List<String> args2 = new ArrayList<>(1);
-    args2.add("m2");
-    agg2.setExpressions(args2);
-    agg2.setAggregationType("max");
-    List<AggregationInfo> aggregationInfos = Lists.newArrayList(agg1, agg2);
+        new ColumnDataType[]{ColumnDataType.STRING, ColumnDataType.INT, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE});
+    AggregationFunction[] aggregationFunctions =
+        new AggregationFunction[]{new SumAggregationFunction("m1"), new MaxAggregationFunction("m2")};
+    SelectionSort selectionSort = new SelectionSort();
+    selectionSort.setColumn("sum(m1)");
+    selectionSort.setIsAsc(true);
+    List<SelectionSort> orderBy = Collections.singletonList(selectionSort);
 
-    SelectionSort sel = new SelectionSort();
-    sel.setColumn("sum(m1)");
-    sel.setIsAsc(true);
-    List<SelectionSort> orderBy = Lists.newArrayList(sel);
-
-    IndexedTable indexedTable = new ConcurrentIndexedTable(dataSchema, aggregationInfos, orderBy, 5);
+    IndexedTable indexedTable = new ConcurrentIndexedTable(dataSchema, aggregationFunctions, orderBy, 5);
 
     // 3 threads upsert together
     // a inserted 6 times (60), b inserted 5 times (50), d inserted 2 times (20)
@@ -81,7 +71,8 @@ public class IndexedTableTest {
       Callable<Void> c1 = () -> {
         indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d, 10d, 100d}));
         indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d, 10d, 200d}));
-        indexedTable.upsert(getKey(new Object[]{"c", 3, 30d}), getRecord(new Object[]{"c", 3, 30d, 10000d, 300d})); // eviction candidate
+        indexedTable.upsert(getKey(new Object[]{"c", 3, 30d}),
+            getRecord(new Object[]{"c", 3, 30d, 10000d, 300d})); // eviction candidate
         indexedTable.upsert(getKey(new Object[]{"d", 4, 40d}), getRecord(new Object[]{"d", 4, 40d, 10d, 400d}));
         indexedTable.upsert(getKey(new Object[]{"d", 4, 40d}), getRecord(new Object[]{"d", 4, 40d, 10d, 400d}));
         indexedTable.upsert(getKey(new Object[]{"e", 5, 50d}), getRecord(new Object[]{"e", 5, 50d, 10d, 500d}));
@@ -90,27 +81,29 @@ public class IndexedTableTest {
 
       Callable<Void> c2 = () -> {
         indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d, 10d, 100d}));
-        indexedTable.upsert(getKey(new Object[]{"f", 6, 60d}), getRecord(new Object[]{"f", 6, 60d,20000d, 600d})); // eviction candidate
-        indexedTable.upsert(getKey(new Object[]{"g", 7, 70d}), getRecord(new Object[]{"g", 7, 70d,10d, 700d}));
-        indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d,10d, 200d}));
-        indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d,10d, 200d}));
-        indexedTable.upsert(getKey(new Object[]{"h", 8, 80d}), getRecord(new Object[]{"h", 8, 80d,10d, 800d}));
-        indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d,10d, 100d}));
-        indexedTable.upsert(getKey(new Object[]{"i", 9, 90d}), getRecord(new Object[]{"i", 9, 90d,500d, 900d}));
+        indexedTable.upsert(getKey(new Object[]{"f", 6, 60d}),
+            getRecord(new Object[]{"f", 6, 60d, 20000d, 600d})); // eviction candidate
+        indexedTable.upsert(getKey(new Object[]{"g", 7, 70d}), getRecord(new Object[]{"g", 7, 70d, 10d, 700d}));
+        indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d, 10d, 200d}));
+        indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d, 10d, 200d}));
+        indexedTable.upsert(getKey(new Object[]{"h", 8, 80d}), getRecord(new Object[]{"h", 8, 80d, 10d, 800d}));
+        indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d, 10d, 100d}));
+        indexedTable.upsert(getKey(new Object[]{"i", 9, 90d}), getRecord(new Object[]{"i", 9, 90d, 500d, 900d}));
         return null;
       };
 
       Callable<Void> c3 = () -> {
-        indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d,10d, 100d}));
-        indexedTable.upsert(getKey(new Object[]{"j", 10, 100d}), getRecord(new Object[]{"j", 10, 100d,10d, 1000d}));
-        indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d,10d, 200d}));
-        indexedTable.upsert(getKey(new Object[]{"k", 11, 110d}), getRecord(new Object[]{"k", 11, 110d,10d, 1100d}));
-        indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d,10d, 100d}));
-        indexedTable.upsert(getKey(new Object[]{"l", 12, 120d}), getRecord(new Object[]{"l", 12, 120d,10d, 1200d}));
-        indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d,10d, 100d})); // trimming candidate
-        indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d,10d, 200d}));
-        indexedTable.upsert(getKey(new Object[]{"m", 13, 130d}), getRecord(new Object[]{"m", 13, 130d,10d, 1300d}));
-        indexedTable.upsert(getKey(new Object[]{"n", 14, 140d}), getRecord(new Object[]{"n", 14, 140d,10d, 1400d}));
+        indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d, 10d, 100d}));
+        indexedTable.upsert(getKey(new Object[]{"j", 10, 100d}), getRecord(new Object[]{"j", 10, 100d, 10d, 1000d}));
+        indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d, 10d, 200d}));
+        indexedTable.upsert(getKey(new Object[]{"k", 11, 110d}), getRecord(new Object[]{"k", 11, 110d, 10d, 1100d}));
+        indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d, 10d, 100d}));
+        indexedTable.upsert(getKey(new Object[]{"l", 12, 120d}), getRecord(new Object[]{"l", 12, 120d, 10d, 1200d}));
+        indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}),
+            getRecord(new Object[]{"a", 1, 10d, 10d, 100d})); // trimming candidate
+        indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d, 10d, 200d}));
+        indexedTable.upsert(getKey(new Object[]{"m", 13, 130d}), getRecord(new Object[]{"m", 13, 130d, 10d, 1300d}));
+        indexedTable.upsert(getKey(new Object[]{"n", 14, 140d}), getRecord(new Object[]{"n", 14, 140d, 10d, 1400d}));
         return null;
       };
 
@@ -122,7 +115,6 @@ public class IndexedTableTest {
       indexedTable.finish(false);
       Assert.assertEquals(indexedTable.size(), 5);
       checkEvicted(indexedTable, "c", "f");
-
     } finally {
       executorService.shutdown();
     }
@@ -130,27 +122,15 @@ public class IndexedTableTest {
 
   @Test(dataProvider = "initDataProvider")
   public void testNonConcurrentIndexedTable(List<SelectionSort> orderBy, List<String> survivors) {
-
     DataSchema dataSchema = new DataSchema(new String[]{"d1", "d2", "d3", "d4", "sum(m1)", "max(m2)"},
         new ColumnDataType[]{ColumnDataType.STRING, ColumnDataType.INT, ColumnDataType.DOUBLE, ColumnDataType.INT, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE});
-
-    AggregationInfo agg1 = new AggregationInfo();
-    List<String> args1 = new ArrayList<>(1);
-    args1.add("m1");
-    agg1.setExpressions(args1);
-    agg1.setAggregationType("sum");
-
-    AggregationInfo agg2 = new AggregationInfo();
-    List<String> args2 = new ArrayList<>(1);
-    args2.add("m2");
-    agg2.setExpressions(args2);
-    agg2.setAggregationType("max");
-    List<AggregationInfo> aggregationInfos = Lists.newArrayList(agg1, agg2);
+    AggregationFunction[] aggregationFunctions =
+        new AggregationFunction[]{new SumAggregationFunction("m1"), new MaxAggregationFunction("m2")};
 
     // Test SimpleIndexedTable
-    IndexedTable simpleIndexedTable = new SimpleIndexedTable(dataSchema, aggregationInfos, orderBy, 5);
+    IndexedTable simpleIndexedTable = new SimpleIndexedTable(dataSchema, aggregationFunctions, orderBy, 5);
     // merge table
-    IndexedTable mergeTable = new SimpleIndexedTable(dataSchema, aggregationInfos, orderBy, 10);
+    IndexedTable mergeTable = new SimpleIndexedTable(dataSchema, aggregationFunctions, orderBy, 10);
     testNonConcurrent(simpleIndexedTable, mergeTable);
 
     // finish
@@ -158,8 +138,8 @@ public class IndexedTableTest {
     checkSurvivors(simpleIndexedTable, survivors);
 
     // Test ConcurrentIndexedTable
-    IndexedTable concurrentIndexedTable = new ConcurrentIndexedTable(dataSchema, aggregationInfos, orderBy, 5);
-    mergeTable = new SimpleIndexedTable(dataSchema, aggregationInfos, orderBy, 10);
+    IndexedTable concurrentIndexedTable = new ConcurrentIndexedTable(dataSchema, aggregationFunctions, orderBy, 5);
+    mergeTable = new SimpleIndexedTable(dataSchema, aggregationFunctions, orderBy, 10);
     testNonConcurrent(concurrentIndexedTable, mergeTable);
 
     // finish
@@ -300,6 +280,7 @@ public class IndexedTableTest {
   private Key getKey(Object[] keys) {
     return new Key(keys);
   }
+
   private Record getRecord(Object[] columns) {
     return new Record(columns);
   }
@@ -307,26 +288,14 @@ public class IndexedTableTest {
   @Test
   public void testNoMoreNewRecords() {
     DataSchema dataSchema = new DataSchema(new String[]{"d1", "d2", "d3", "sum(m1)", "max(m2)"},
-        new ColumnDataType[]{ColumnDataType.STRING, ColumnDataType.INT, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE,
-            ColumnDataType.DOUBLE});
-
-    AggregationInfo agg1 = new AggregationInfo();
-    List<String> args1 = new ArrayList<>(1);
-    args1.add("m1");
-    agg1.setExpressions(args1);
-    agg1.setAggregationType("sum");
-
-    AggregationInfo agg2 = new AggregationInfo();
-    List<String> args2 = new ArrayList<>(1);
-    args2.add("m2");
-    agg2.setExpressions(args2);
-    agg2.setAggregationType("max");
-    List<AggregationInfo> aggregationInfos = Lists.newArrayList(agg1, agg2);
-
-    IndexedTable indexedTable = new SimpleIndexedTable(dataSchema, aggregationInfos, null, 5);
+        new ColumnDataType[]{ColumnDataType.STRING, ColumnDataType.INT, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE});
+    AggregationFunction[] aggregationFunctions =
+        new AggregationFunction[]{new SumAggregationFunction("m1"), new MaxAggregationFunction("m2")};
+
+    IndexedTable indexedTable = new SimpleIndexedTable(dataSchema, aggregationFunctions, null, 5);
     testNoMoreNewRecordsInTable(indexedTable);
 
-    indexedTable = new ConcurrentIndexedTable(dataSchema, aggregationInfos, null, 5);
+    indexedTable = new ConcurrentIndexedTable(dataSchema, aggregationFunctions, null, 5);
     testNoMoreNewRecordsInTable(indexedTable);
   }
 
@@ -355,6 +324,5 @@ public class IndexedTableTest {
     indexedTable.finish(false);
 
     checkEvicted(indexedTable, "f", "g");
-
   }
 }
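The updated tests build the table from SumAggregationFunction("m1") and MaxAggregationFunction("m2"), order by sum(m1) ascending, and expect the high-sum keys c and f to be evicted. The single-threaded sketch below mimics that merge-then-trim behavior on a plain HashMap. It is a simplification, not how SimpleIndexedTable or ConcurrentIndexedTable work: their capacity thresholds and concurrency handling are not shown in this diff. The upsert() and trim() names here are hypothetical.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

public class UpsertAndTrimSketch {

  /** Merges one record into the table: SUM(m1) adds, MAX(m2) keeps the larger value. */
  static void upsert(Map<String, double[]> table, String key, double m1, double m2) {
    double[] existing = table.get(key);
    if (existing == null) {
      table.put(key, new double[]{m1, m2});
    } else {
      existing[0] += m1;
      existing[1] = Math.max(existing[1], m2);
    }
  }

  /** Keeps the trimToSize keys with the smallest sum (ORDER BY sum(m1) ASC) and evicts the rest. */
  static void trim(Map<String, double[]> table, int trimToSize) {
    if (table.size() <= trimToSize) {
      return;
    }
    List<String> keys = new ArrayList<>(table.keySet());
    keys.sort(Comparator.comparingDouble((String k) -> table.get(k)[0]));
    // Iterate over a separate list, so removing from the map is safe.
    for (String evicted : keys.subList(trimToSize, keys.size())) {
      table.remove(evicted);
    }
  }

  public static void main(String[] args) {
    Map<String, double[]> table = new HashMap<>();
    upsert(table, "a", 10d, 100d);
    upsert(table, "c", 10000d, 300d);  // large sum: eviction candidate
    upsert(table, "a", 10d, 100d);
    upsert(table, "f", 20000d, 600d);  // large sum: eviction candidate
    upsert(table, "b", 10d, 200d);
    trim(table, 2);
    // TreeMap is used only to print keys in a stable order.
    System.out.println(new TreeMap<>(table).keySet());                // [a, b]
    System.out.println(table.get("a")[0] + " " + table.get("a")[1]);  // 20.0 100.0
  }
}
```

Ties in sum(m1) are broken arbitrarily here. The TableResizer tests instead add further SelectionSort entries (for example d3 desc) as secondary comparators.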
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/table/TableResizerTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/table/TableResizerTest.java
index 15c1e3f..61607c7 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/data/table/TableResizerTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/data/table/TableResizerTest.java
@@ -20,18 +20,20 @@ package org.apache.pinot.core.data.table;
 
 import com.google.common.collect.Lists;
 import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
-import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import org.apache.pinot.common.request.AggregationInfo;
 import org.apache.pinot.common.request.SelectionSort;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
+import org.apache.pinot.core.query.aggregation.function.AvgAggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.DistinctCountAggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.MaxAggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.SumAggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.customobject.AvgPair;
 import org.testng.Assert;
 import org.testng.annotations.BeforeClass;
@@ -42,14 +44,8 @@ import org.testng.annotations.Test;
  * Tests the functionality of {@link @TableResizer}
  */
 public class TableResizerTest {
-
-  private DataSchema dataSchema;
-  private List<AggregationInfo> aggregationInfos;
-  private List<SelectionSort> selectionSort;
-  private SelectionSort sel1;
-  private SelectionSort sel2;
-  private SelectionSort sel3;
-  private TableResizer _tableResizer;
+  private DataSchema _dataSchema;
+  private AggregationFunction[] _aggregationFunctions;
 
   private int trimToSize = 3;
   private Map<Key, Record> _recordsMap;
@@ -57,37 +53,11 @@ public class TableResizerTest {
   private List<Key> _keys;
 
   @BeforeClass
-  public void beforeClass() {
-    dataSchema = new DataSchema(new String[]{"d1", "d2", "d3", "sum(m1)", "max(m2)", "distinctcount(m3)", "avg(m4)"},
+  public void setUp() {
+    _dataSchema = new DataSchema(new String[]{"d1", "d2", "d3", "sum(m1)", "max(m2)", "distinctcount(m3)", "avg(m4)"},
         new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.OBJECT, DataSchema.ColumnDataType.OBJECT});
-    AggregationInfo agg1 = new AggregationInfo();
-    List<String> args1 = new ArrayList<>(1);
-    args1.add("m1");
-    agg1.setExpressions(args1);
-    agg1.setAggregationType("sum");
-
-    AggregationInfo agg2 = new AggregationInfo();
-    List<String> args2 = new ArrayList<>(1);
-    args2.add("m2");
-    agg2.setExpressions(args2);
-    agg2.setAggregationType("max");
-
-    AggregationInfo agg3 = new AggregationInfo();
-    List<String> args3 = new ArrayList<>(1);
-    args3.add("m3");
-    agg3.setExpressions(args3);
-    agg3.setAggregationType("distinctcount");
-
-    AggregationInfo agg4 = new AggregationInfo();
-    List<String> args4 = new ArrayList<>(1);
-    args4.add("m4");
-    agg4.setExpressions(args4);
-    agg4.setAggregationType("avg");
-    aggregationInfos = Lists.newArrayList(agg1, agg2, agg3, agg4);
-
-    sel1 = new SelectionSort();
-    sel2 = new SelectionSort();
-    sel3 = new SelectionSort();
+    _aggregationFunctions = new AggregationFunction[]{new SumAggregationFunction("m1"), new MaxAggregationFunction(
+        "m2"), new DistinctCountAggregationFunction("m3"), new AvgAggregationFunction("m4")};
 
     IntOpenHashSet i1 = new IntOpenHashSet();
     i1.add(1);
@@ -134,95 +104,93 @@ public class TableResizerTest {
   @Test
   public void testResizeRecordsMap() {
     Map<Key, Record> recordsMap;
+    SelectionSort selectionSort1 = new SelectionSort();
+    SelectionSort selectionSort2 = new SelectionSort();
+    SelectionSort selectionSort3 = new SelectionSort();
     // Test resize algorithm with numRecordsToEvict < trimToSize.
     // TotalRecords=5; trimToSize=3; numRecordsToEvict=2
 
     // d1 asc
-    sel1.setColumn("d1");
-    sel1.setIsAsc(true);
-    selectionSort = Lists.newArrayList(sel1);
+    selectionSort1.setColumn("d1");
+    selectionSort1.setIsAsc(true);
     recordsMap = new HashMap<>(_recordsMap);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
-    _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+    TableResizer tableResizer =
+        new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
+    tableResizer.resizeRecordsMap(recordsMap, trimToSize);
     Assert.assertEquals(recordsMap.size(), trimToSize);
     Assert.assertTrue(recordsMap.containsKey(_keys.get(0))); // a, b, c
     Assert.assertTrue(recordsMap.containsKey(_keys.get(1)));
 
     // d1 desc
-    sel1.setColumn("d1");
-    sel1.setIsAsc(false);
-    selectionSort = Lists.newArrayList(sel1);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("d1");
+    selectionSort1.setIsAsc(false);
+    tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
     recordsMap = new HashMap<>(_recordsMap);
-    _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+    tableResizer.resizeRecordsMap(recordsMap, trimToSize);
     Assert.assertEquals(recordsMap.size(), trimToSize);
     Assert.assertTrue(recordsMap.containsKey(_keys.get(2))); // c, c, c
     Assert.assertTrue(recordsMap.containsKey(_keys.get(3)));
     Assert.assertTrue(recordsMap.containsKey(_keys.get(4)));
 
     // d1 asc, d3 desc (tie breaking with 2nd comparator
-    sel1.setColumn("d1");
-    sel1.setIsAsc(true);
-    sel2.setColumn("d3");
-    sel2.setIsAsc(false);
-    selectionSort = Lists.newArrayList(sel1, sel2);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("d1");
+    selectionSort1.setIsAsc(true);
+    selectionSort2.setColumn("d3");
+    selectionSort2.setIsAsc(false);
+    tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Arrays.asList(selectionSort1, selectionSort2));
     recordsMap = new HashMap<>(_recordsMap);
-    _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+    tableResizer.resizeRecordsMap(recordsMap, trimToSize);
     Assert.assertEquals(recordsMap.size(), trimToSize);
     Assert.assertTrue(recordsMap.containsKey(_keys.get(0))); // 10, 10, 300
     Assert.assertTrue(recordsMap.containsKey(_keys.get(1)));
     Assert.assertTrue(recordsMap.containsKey(_keys.get(4)));
 
     // d2 asc
-    sel1.setColumn("d2");
-    sel1.setIsAsc(true);
-    selectionSort = Lists.newArrayList(sel1);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("d2");
+    selectionSort1.setIsAsc(true);
+    tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
     recordsMap = new HashMap<>(_recordsMap);
-    _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+    tableResizer.resizeRecordsMap(recordsMap, trimToSize);
     Assert.assertEquals(recordsMap.size(), trimToSize);
     Assert.assertTrue(recordsMap.containsKey(_keys.get(0))); // 10, 10, 50
     Assert.assertTrue(recordsMap.containsKey(_keys.get(1)));
     Assert.assertTrue(recordsMap.containsKey(_keys.get(3)));
 
     // d1 asc, sum(m1) desc, max(m2) desc
-    sel1.setColumn("d1");
-    sel1.setIsAsc(true);
-    sel2.setColumn("sum(m1)");
-    sel2.setIsAsc(false);
-    sel3.setColumn("max(m2)");
-    sel3.setIsAsc(false);
-    selectionSort = Lists.newArrayList(sel1, sel2, sel3);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("d1");
+    selectionSort1.setIsAsc(true);
+    selectionSort2.setColumn("sum(m1)");
+    selectionSort2.setIsAsc(false);
+    selectionSort3.setColumn("max(m2)");
+    selectionSort3.setIsAsc(false);
+    tableResizer = new TableResizer(_dataSchema, _aggregationFunctions,
+        Arrays.asList(selectionSort1, selectionSort2, selectionSort3));
     recordsMap = new HashMap<>(_recordsMap);
-    _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+    tableResizer.resizeRecordsMap(recordsMap, trimToSize);
     Assert.assertEquals(recordsMap.size(), trimToSize);
     Assert.assertTrue(recordsMap.containsKey(_keys.get(0))); // a, b, (c (30, 300))
     Assert.assertTrue(recordsMap.containsKey(_keys.get(1)));
     Assert.assertTrue(recordsMap.containsKey(_keys.get(2)));
 
     // object type avg(m4) asc
-    sel1.setColumn("avg(m4)");
-    sel1.setIsAsc(true);
-    selectionSort = Lists.newArrayList(sel1);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("avg(m4)");
+    selectionSort1.setIsAsc(true);
+    tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
     recordsMap = new HashMap<>(_recordsMap);
-    _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+    tableResizer.resizeRecordsMap(recordsMap, trimToSize);
     Assert.assertEquals(recordsMap.size(), trimToSize);
     Assert.assertTrue(recordsMap.containsKey(_keys.get(4))); // 2, 3, 3.33,
     Assert.assertTrue(recordsMap.containsKey(_keys.get(3)));
     Assert.assertTrue(recordsMap.containsKey(_keys.get(1)));
 
     // non-comparable intermediate result
-    sel1.setColumn("distinctcount(m3)");
-    sel1.setIsAsc(false);
-    sel2.setColumn("d1");
-    sel2.setIsAsc(true);
-    selectionSort = Lists.newArrayList(sel1, sel2);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("distinctcount(m3)");
+    selectionSort1.setIsAsc(false);
+    selectionSort2.setColumn("d1");
+    selectionSort2.setIsAsc(true);
+    tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Arrays.asList(selectionSort1, selectionSort2));
     recordsMap = new HashMap<>(_recordsMap);
-    _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+    tableResizer.resizeRecordsMap(recordsMap, trimToSize);
     Assert.assertEquals(recordsMap.size(), trimToSize);
     Assert.assertTrue(recordsMap.containsKey(_keys.get(4))); // 6, 5, 4 (b)
     Assert.assertTrue(recordsMap.containsKey(_keys.get(3)));
@@ -233,36 +201,33 @@ public class TableResizerTest {
     trimToSize = 2;
 
     // d1 asc
-    sel1.setColumn("d1");
-    sel1.setIsAsc(true);
-    selectionSort = Lists.newArrayList(sel1);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("d1");
+    selectionSort1.setIsAsc(true);
+    tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
     recordsMap = new HashMap<>(_recordsMap);
-    _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+    tableResizer.resizeRecordsMap(recordsMap, trimToSize);
     Assert.assertEquals(recordsMap.size(), trimToSize);
     Assert.assertTrue(recordsMap.containsKey(_keys.get(0))); // a, b
     Assert.assertTrue(recordsMap.containsKey(_keys.get(1)));
 
     // object type avg(m4) asc
-    sel1.setColumn("avg(m4)");
-    sel1.setIsAsc(true);
-    selectionSort = Lists.newArrayList(sel1);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("avg(m4)");
+    selectionSort1.setIsAsc(true);
+    tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
     recordsMap = new HashMap<>(_recordsMap);
-    _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+    tableResizer.resizeRecordsMap(recordsMap, trimToSize);
     Assert.assertEquals(recordsMap.size(), trimToSize);
     Assert.assertTrue(recordsMap.containsKey(_keys.get(4))); // 2, 3, 3.33,
     Assert.assertTrue(recordsMap.containsKey(_keys.get(3)));
 
     // non-comparable intermediate result
-    sel1.setColumn("distinctcount(m3)");
-    sel1.setIsAsc(false);
-    sel2.setColumn("d1");
-    sel2.setIsAsc(true);
-    selectionSort = Lists.newArrayList(sel1, sel2);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("distinctcount(m3)");
+    selectionSort1.setIsAsc(false);
+    selectionSort2.setColumn("d1");
+    selectionSort2.setIsAsc(true);
+    tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Arrays.asList(selectionSort1, selectionSort2));
     recordsMap = new HashMap<>(_recordsMap);
-    _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+    tableResizer.resizeRecordsMap(recordsMap, trimToSize);
     Assert.assertEquals(recordsMap.size(), trimToSize);
     Assert.assertTrue(recordsMap.containsKey(_keys.get(4))); // 6, 5, 4 (b)
     Assert.assertTrue(recordsMap.containsKey(_keys.get(3)));
@@ -279,14 +244,17 @@ public class TableResizerTest {
     List<Record> sortedRecords;
     int[] order;
     Map<Key, Record> recordsMap;
+    SelectionSort selectionSort1 = new SelectionSort();
+    SelectionSort selectionSort2 = new SelectionSort();
+    SelectionSort selectionSort3 = new SelectionSort();
 
     // d1 asc
-    sel1.setColumn("d1");
-    sel1.setIsAsc(true);
-    selectionSort = Lists.newArrayList(sel1);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("d1");
+    selectionSort1.setIsAsc(true);
+    TableResizer tableResizer =
+        new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
     recordsMap = new HashMap<>(_recordsMap);
-    sortedRecords = _tableResizer.resizeAndSortRecordsMap(recordsMap, trimToSize);
+    sortedRecords = tableResizer.resizeAndSortRecordsMap(recordsMap, trimToSize);
     Assert.assertEquals(sortedRecords.size(), trimToSize);
     order = new int[]{0, 1};
     for (int i = 0; i < order.length; i++) {
@@ -295,7 +263,7 @@ public class TableResizerTest {
 
     // d1 asc - trim to 1
     recordsMap = new HashMap<>(_recordsMap);
-    sortedRecords = _tableResizer.resizeAndSortRecordsMap(recordsMap, 1);
+    sortedRecords = tableResizer.resizeAndSortRecordsMap(recordsMap, 1);
     Assert.assertEquals(sortedRecords.size(), 1);
     order = new int[]{0};
     for (int i = 0; i < order.length; i++) {
@@ -303,14 +271,13 @@ public class TableResizerTest {
     }
 
     // d1 asc, d3 desc (tie breaking with 2nd comparator)
-    sel1.setColumn("d1");
-    sel1.setIsAsc(true);
-    sel2.setColumn("d3");
-    sel2.setIsAsc(false);
-    selectionSort = Lists.newArrayList(sel1, sel2);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("d1");
+    selectionSort1.setIsAsc(true);
+    selectionSort2.setColumn("d3");
+    selectionSort2.setIsAsc(false);
+    tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Arrays.asList(selectionSort1, selectionSort2));
     recordsMap = new HashMap<>(_recordsMap);
-    sortedRecords = _tableResizer.resizeAndSortRecordsMap(recordsMap, trimToSize);
+    sortedRecords = tableResizer.resizeAndSortRecordsMap(recordsMap, trimToSize);
     Assert.assertEquals(sortedRecords.size(), trimToSize);
     order = new int[]{0, 1, 4};
     for (int i = 0; i < order.length; i++) {
@@ -319,7 +286,7 @@ public class TableResizerTest {
 
     // d1 asc, d3 desc (tie breaking with 2nd comparator) - trim 1
     recordsMap = new HashMap<>(_recordsMap);
-    sortedRecords = _tableResizer.resizeAndSortRecordsMap(recordsMap, 1);
+    sortedRecords = tableResizer.resizeAndSortRecordsMap(recordsMap, 1);
     Assert.assertEquals(sortedRecords.size(), 1);
     order = new int[]{0};
     for (int i = 0; i < order.length; i++) {
@@ -327,16 +294,16 @@ public class TableResizerTest {
     }
 
     // d1 asc, sum(m1) desc, max(m2) desc
-    sel1.setColumn("d1");
-    sel1.setIsAsc(true);
-    sel2.setColumn("sum(m1)");
-    sel2.setIsAsc(false);
-    sel3.setColumn("max(m2)");
-    sel3.setIsAsc(false);
-    selectionSort = Lists.newArrayList(sel1, sel2, sel3);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("d1");
+    selectionSort1.setIsAsc(true);
+    selectionSort2.setColumn("sum(m1)");
+    selectionSort2.setIsAsc(false);
+    selectionSort3.setColumn("max(m2)");
+    selectionSort3.setIsAsc(false);
+    tableResizer = new TableResizer(_dataSchema, _aggregationFunctions,
+        Arrays.asList(selectionSort1, selectionSort2, selectionSort3));
     recordsMap = new HashMap<>(_recordsMap);
-    sortedRecords = _tableResizer.resizeAndSortRecordsMap(recordsMap, trimToSize);
+    sortedRecords = tableResizer.resizeAndSortRecordsMap(recordsMap, trimToSize);
     Assert.assertEquals(sortedRecords.size(), trimToSize);
     order = new int[]{0, 1, 2};
     for (int i = 0; i < order.length; i++) {
@@ -345,7 +312,7 @@ public class TableResizerTest {
 
     // trim 1
     recordsMap = new HashMap<>(_recordsMap);
-    sortedRecords = _tableResizer.resizeAndSortRecordsMap(recordsMap, 1);
+    sortedRecords = tableResizer.resizeAndSortRecordsMap(recordsMap, 1);
     Assert.assertEquals(sortedRecords.size(), 1);
     order = new int[]{0};
     for (int i = 0; i < order.length; i++) {
@@ -353,14 +320,13 @@ public class TableResizerTest {
     }
 
     // object type avg(m4) asc
-    sel1.setColumn("avg(m4)");
-    sel1.setIsAsc(true);
-    sel2.setColumn("d1");
-    sel2.setIsAsc(true);
-    selectionSort = Lists.newArrayList(sel1, sel2);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("avg(m4)");
+    selectionSort1.setIsAsc(true);
+    selectionSort2.setColumn("d1");
+    selectionSort2.setIsAsc(true);
+    tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Arrays.asList(selectionSort1, selectionSort2));
     recordsMap = new HashMap<>(_recordsMap);
-    sortedRecords = _tableResizer.resizeAndSortRecordsMap(recordsMap, 10); // high trim to size
+    sortedRecords = tableResizer.resizeAndSortRecordsMap(recordsMap, 10); // high trim to size
     Assert.assertEquals(sortedRecords.size(), recordsMap.size());
     order = new int[]{4, 3, 1, 0, 2};
     for (int i = 0; i < order.length; i++) {
@@ -368,14 +334,13 @@ public class TableResizerTest {
     }
 
     // non-comparable intermediate result
-    sel1.setColumn("distinctcount(m3)");
-    sel1.setIsAsc(false);
-    sel2.setColumn("avg(m4)");
-    sel2.setIsAsc(false);
-    selectionSort = Lists.newArrayList(sel1, sel2);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("distinctcount(m3)");
+    selectionSort1.setIsAsc(false);
+    selectionSort2.setColumn("avg(m4)");
+    selectionSort2.setIsAsc(false);
+    tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Arrays.asList(selectionSort1, selectionSort2));
     recordsMap = new HashMap<>(_recordsMap);
-    sortedRecords = _tableResizer.resizeAndSortRecordsMap(recordsMap, recordsMap.size()); // equal trim to size
+    sortedRecords = tableResizer.resizeAndSortRecordsMap(recordsMap, recordsMap.size()); // equal trim to size
     Assert.assertEquals(sortedRecords.size(), recordsMap.size());
     order = new int[]{4, 3, 2, 1, 0};
     for (int i = 0; i < order.length; i++) {
@@ -388,45 +353,43 @@ public class TableResizerTest {
    */
   @Test
   public void testIntermediateRecord() {
+    SelectionSort selectionSort1 = new SelectionSort();
+    SelectionSort selectionSort2 = new SelectionSort();
+    SelectionSort selectionSort3 = new SelectionSort();
 
     // d2
-    sel1.setColumn("d2");
-    selectionSort = Lists.newArrayList(sel1);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("d2");
+    TableResizer tableResizer =
+        new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
     for (Map.Entry<Key, Record> entry : _recordsMap.entrySet()) {
       Key key = entry.getKey();
       Record record = entry.getValue();
-      TableResizer.IntermediateRecord intermediateRecord =
-          _tableResizer.getIntermediateRecord(key, record);
+      TableResizer.IntermediateRecord intermediateRecord = tableResizer.getIntermediateRecord(key, record);
       Assert.assertEquals(intermediateRecord._key, key);
       Assert.assertEquals(intermediateRecord._values.length, 1);
       Assert.assertEquals(intermediateRecord._values[0], record.getValues()[1]);
     }
 
     // sum(m1)
-    sel1.setColumn("sum(m1)");
-    selectionSort = Lists.newArrayList(sel1);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("sum(m1)");
+    tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
     for (Map.Entry<Key, Record> entry : _recordsMap.entrySet()) {
       Key key = entry.getKey();
       Record record = entry.getValue();
-      TableResizer.IntermediateRecord intermediateRecord =
-          _tableResizer.getIntermediateRecord(key, record);
+      TableResizer.IntermediateRecord intermediateRecord = tableResizer.getIntermediateRecord(key, record);
       Assert.assertEquals(intermediateRecord._key, key);
       Assert.assertEquals(intermediateRecord._values.length, 1);
       Assert.assertEquals(intermediateRecord._values[0], record.getValues()[3]);
     }
 
     // d1, max(m2)
-    sel1.setColumn("d1");
-    sel2.setColumn("max(m2)");
-    selectionSort = Lists.newArrayList(sel1, sel2);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("d1");
+    selectionSort2.setColumn("max(m2)");
+    tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Arrays.asList(selectionSort1, selectionSort2));
     for (Map.Entry<Key, Record> entry : _recordsMap.entrySet()) {
       Key key = entry.getKey();
       Record record = entry.getValue();
-      TableResizer.IntermediateRecord intermediateRecord =
-          _tableResizer.getIntermediateRecord(key, record);
+      TableResizer.IntermediateRecord intermediateRecord = tableResizer.getIntermediateRecord(key, record);
       Assert.assertEquals(intermediateRecord._key, key);
       Assert.assertEquals(intermediateRecord._values.length, 2);
       Assert.assertEquals(intermediateRecord._values[0], record.getValues()[0]);
@@ -434,16 +397,15 @@ public class TableResizerTest {
     }
 
     // d2, sum(m1), d3
-    sel1.setColumn("d2");
-    sel2.setColumn("sum(m1)");
-    sel3.setColumn("d3");
-    selectionSort = Lists.newArrayList(sel1, sel2, sel3);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+    selectionSort1.setColumn("d2");
+    selectionSort2.setColumn("sum(m1)");
+    selectionSort3.setColumn("d3");
+    tableResizer = new TableResizer(_dataSchema, _aggregationFunctions,
+        Arrays.asList(selectionSort1, selectionSort2, selectionSort3));
     for (Map.Entry<Key, Record> entry : _recordsMap.entrySet()) {
       Key key = entry.getKey();
       Record record = entry.getValue();
-      TableResizer.IntermediateRecord intermediateRecord =
-          _tableResizer.getIntermediateRecord(key, record);
+      TableResizer.IntermediateRecord intermediateRecord = tableResizer.getIntermediateRecord(key, record);
       Assert.assertEquals(intermediateRecord._key, key);
       Assert.assertEquals(intermediateRecord._values.length, 3);
       Assert.assertEquals(intermediateRecord._values[0], record.getValues()[1]);
@@ -452,20 +414,17 @@ public class TableResizerTest {
     }
 
     // non-comparable intermediate result
-    sel1.setColumn("distinctcount(m3)");
-    selectionSort = Lists.newArrayList(sel1);
-    _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
-    AggregationFunction distinctCountFunction =
-        AggregationFunctionUtils.getAggregationFunctionContext(aggregationInfos.get(2)).getAggregationFunction();
+    selectionSort1.setColumn("distinctcount(m3)");
+    tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
+    AggregationFunction distinctCountFunction = _aggregationFunctions[2];
     for (Map.Entry<Key, Record> entry : _recordsMap.entrySet()) {
       Key key = entry.getKey();
       Record record = entry.getValue();
-      TableResizer.IntermediateRecord intermediateRecord =
-          _tableResizer.getIntermediateRecord(key, record);
+      TableResizer.IntermediateRecord intermediateRecord = tableResizer.getIntermediateRecord(key, record);
       Assert.assertEquals(intermediateRecord._key, key);
       Assert.assertEquals(intermediateRecord._values.length, 1);
-      Assert.assertEquals(intermediateRecord._values[0],
-          distinctCountFunction.extractFinalResult(record.getValues()[5]));
+      Assert
+          .assertEquals(intermediateRecord._values[0], distinctCountFunction.extractFinalResult(record.getValues()[5]));
     }
   }
 
@@ -479,7 +438,7 @@ public class TableResizerTest {
     selectionSort.setColumn("STRING_COL");
     selectionSort.setIsAsc(true);
 
-    TableResizer tableResizer = new TableResizer(schema, Collections.emptyList(), Lists.newArrayList(selectionSort));
+    TableResizer tableResizer = new TableResizer(schema, new AggregationFunction[0], Lists.newArrayList(selectionSort));
     Set<Record> uniqueRecordsSet = new HashSet<>();
 
     Record r1 = new Record(new Object[]{"B"});
@@ -559,7 +518,7 @@ public class TableResizerTest {
 
     // change the order to DESC
     selectionSort.setIsAsc(false);
-    tableResizer = new TableResizer(schema, Collections.emptyList(), Lists.newArrayList(selectionSort));
+    tableResizer = new TableResizer(schema, new AggregationFunction[0], Lists.newArrayList(selectionSort));
 
     trimSize = 5;
     // no records should have been evicted
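The TableResizerTest cases above sort the grouped records by one or more keys, for example "d1 asc, d3 desc" where d3 only breaks ties on d1, and then trim the result to a fixed size. Below is a minimal standalone sketch of that sort-then-trim idea. It assumes a simplified, hypothetical Row type that carries only d1 and d3. It is not Pinot's TableResizer, which operates on Key/Record maps and can also sort on aggregation results such as sum(m1) or avg(m4).

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ResizeSketch {
  // Hypothetical, simplified stand-in for a group-by result row (not Pinot's Record class).
  static final class Row {
    final String d1;
    final int d3;

    Row(String d1, int d3) {
      this.d1 = d1;
      this.d3 = d3;
    }

    @Override
    public String toString() {
      return d1 + "/" + d3;
    }
  }

  // "d1 asc, d3 desc": the second key is consulted only when two rows share the same d1.
  static final Comparator<Row> D1_ASC_D3_DESC = Comparator.comparing((Row r) -> r.d1)
      .thenComparing(Comparator.comparingInt((Row r) -> r.d3).reversed());

  // Sorts every row in the map with the comparator and keeps at most trimToSize rows.
  // A full sort keeps the sketch short; a bounded heap would avoid ordering rows that get dropped.
  static List<Row> resizeAndSort(Map<String, Row> recordsMap, Comparator<Row> comparator, int trimToSize) {
    List<Row> rows = new ArrayList<>(recordsMap.values());
    rows.sort(comparator);
    return new ArrayList<>(rows.subList(0, Math.min(trimToSize, rows.size())));
  }

  public static void main(String[] args) {
    Map<String, Row> recordsMap = new HashMap<>();
    recordsMap.put("k0", new Row("a", 10));
    recordsMap.put("k1", new Row("a", 30));
    recordsMap.put("k2", new Row("c", 5));
    recordsMap.put("k3", new Row("b", 20));
    // prints [a/30, a/10, b/20]
    System.out.println(resizeAndSort(recordsMap, D1_ASC_D3_DESC, 3));
  }
}
```

Trimming to a size larger than the map simply returns every row in sorted order, which mirrors the "high trim to size" case in the test.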
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
index 442d86b..9c7dfed 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
@@ -20,7 +20,6 @@ package org.apache.pinot.core.query.aggregation.function;
 
 import java.util.Arrays;
 import java.util.Collections;
-import java.util.List;
 import org.apache.pinot.common.function.AggregationFunctionType;
 import org.apache.pinot.common.request.AggregationInfo;
 import org.apache.pinot.common.request.BrokerRequest;
@@ -36,7 +35,7 @@ public class AggregationFunctionFactoryTest {
     AggregationFunction aggregationFunction;
 
     BrokerRequest brokerRequest = new BrokerRequest();
-    String column = "testColumn";
+    String column;
 
     AggregationInfo aggregationInfo = new AggregationInfo();
     aggregationInfo.setAggregationType("CoUnT");
@@ -270,9 +269,9 @@ public class AggregationFunctionFactoryTest {
     AggregationInfo aggregationInfo = new AggregationInfo();
 
     aggregationInfo.setAggregationType("distinct");
-    List<String> args = Arrays.asList("column1", "column2", "column3");
-    String expected = "distinct_" + AggregationFunctionUtils.concatArgs(args);
-    aggregationInfo.setExpressions(args);
+    String[] arguments = new String[]{"column1", "column2", "column3"};
+    String expected = "distinct_" + AggregationFunctionUtils.concatArgs(arguments);
+    aggregationInfo.setExpressions(Arrays.asList(arguments));
 
     AggregationFunction aggregationFunction =
         AggregationFunctionFactory.getAggregationFunction(aggregationInfo, brokerRequest);
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/BaseStarTreeV2Test.java b/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/BaseStarTreeV2Test.java
index 03177ae..daa5ddb 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/BaseStarTreeV2Test.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/BaseStarTreeV2Test.java
@@ -30,12 +30,7 @@ import java.util.Map;
 import java.util.Random;
 import java.util.Set;
 import org.apache.commons.io.FileUtils;
-import org.apache.pinot.spi.config.table.TableConfig;
-import org.apache.pinot.spi.config.table.TableType;
-import org.apache.pinot.spi.data.FieldSpec.DataType;
-import org.apache.pinot.spi.data.Schema;
 import org.apache.pinot.common.function.AggregationFunctionType;
-import org.apache.pinot.common.request.AggregationInfo;
 import org.apache.pinot.common.request.BrokerRequest;
 import org.apache.pinot.common.request.GroupBy;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -46,7 +41,6 @@ import org.apache.pinot.core.common.BlockDocIdIterator;
 import org.apache.pinot.core.common.BlockSingleValIterator;
 import org.apache.pinot.core.common.Constants;
 import org.apache.pinot.core.common.DataSource;
-import org.apache.pinot.spi.data.readers.GenericRow;
 import org.apache.pinot.core.data.aggregator.ValueAggregator;
 import org.apache.pinot.core.data.readers.GenericRowRecordReader;
 import org.apache.pinot.core.indexsegment.IndexSegment;
@@ -54,6 +48,7 @@ import org.apache.pinot.core.indexsegment.generator.SegmentGeneratorConfig;
 import org.apache.pinot.core.indexsegment.immutable.ImmutableSegmentLoader;
 import org.apache.pinot.core.plan.FilterPlanNode;
 import org.apache.pinot.core.plan.PlanNode;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 import org.apache.pinot.core.segment.creator.impl.SegmentIndexCreationDriverImpl;
 import org.apache.pinot.core.segment.index.readers.Dictionary;
@@ -62,6 +57,11 @@ import org.apache.pinot.core.startree.v2.builder.MultipleTreesBuilder;
 import org.apache.pinot.core.startree.v2.builder.MultipleTreesBuilder.BuildMode;
 import org.apache.pinot.core.startree.v2.builder.StarTreeV2BuilderConfig;
 import org.apache.pinot.pql.parsers.Pql2Compiler;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.data.readers.GenericRow;
 import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
 import org.testng.Assert;
 import org.testng.annotations.AfterClass;
@@ -69,6 +69,7 @@ import org.testng.annotations.BeforeClass;
 import org.testng.annotations.Test;
 
 import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
 
 
 /**
@@ -118,14 +119,12 @@ abstract class BaseStarTreeV2Test<R, A> {
 
     List<GenericRow> segmentRecords = new ArrayList<>(NUM_SEGMENT_RECORDS);
     for (int i = 0; i < NUM_SEGMENT_RECORDS; i++) {
-      Map<String, Object> fieldMap = new HashMap<>();
-      fieldMap.put(DIMENSION_D1, RANDOM.nextInt(DIMENSION_CARDINALITY));
-      fieldMap.put(DIMENSION_D2, RANDOM.nextInt(DIMENSION_CARDINALITY));
+      GenericRow segmentRecord = new GenericRow();
+      segmentRecord.putValue(DIMENSION_D1, RANDOM.nextInt(DIMENSION_CARDINALITY));
+      segmentRecord.putValue(DIMENSION_D2, RANDOM.nextInt(DIMENSION_CARDINALITY));
       if (rawValueType != null) {
-        fieldMap.put(METRIC, getRandomRawValue(RANDOM));
+        segmentRecord.putValue(METRIC, getRandomRawValue(RANDOM));
       }
-      GenericRow segmentRecord = new GenericRow();
-      segmentRecord.init(fieldMap);
       segmentRecords.add(segmentRecord);
     }
 
@@ -184,11 +183,14 @@ abstract class BaseStarTreeV2Test<R, A> {
     BrokerRequest brokerRequest = COMPILER.compileToBrokerRequest(query);
 
     // Aggregations
-    List<AggregationInfo> aggregationInfos = brokerRequest.getAggregationsInfo();
-    int numAggregations = aggregationInfos.size();
+    AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(brokerRequest);
+    int numAggregations = aggregationFunctions.length;
     List<AggregationFunctionColumnPair> functionColumnPairs = new ArrayList<>(numAggregations);
-    for (AggregationInfo aggregationInfo : aggregationInfos) {
-      functionColumnPairs.add(AggregationFunctionUtils.getFunctionColumnPair(aggregationInfo));
+    for (AggregationFunction aggregationFunction : aggregationFunctions) {
+      AggregationFunctionColumnPair aggregationFunctionColumnPair =
+          AggregationFunctionUtils.getAggregationFunctionColumnPair(aggregationFunction);
+      assertNotNull(aggregationFunctionColumnPair);
+      functionColumnPairs.add(aggregationFunctionColumnPair);
     }
 
     // Group-by columns
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java
index efc3d40..3ba51c2 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java
@@ -71,33 +71,41 @@ public class InterSegmentAggregationMultiValueQueriesTest extends BaseMultiValue
     aggregations.add(new String[]{"35436"});
     aggregations.add(new String[]{"33576"});
     aggregations.add(new String[]{"24300"});
-    QueriesTestUtils
-        .testInterSegmentAggregationGroupByResult(brokerResponse, 400000L, 0L, 800000L, 400000L,
-            groupKeys, aggregations);
+    QueriesTestUtils.testInterSegmentAggregationGroupByResult(brokerResponse, 400000L, 0L, 800000L, 400000L, groupKeys,
+        aggregations);
 
-    query = "SELECT VALUEIN(column7, 363, 469, 246, 100000), COUNTMV(column6) FROM testTable GROUP BY VALUEIN(column7, 363, 469, 246, 100000)";
+    query =
+        "SELECT VALUEIN(column7, 363, 469, 246, 100000), COUNTMV(column6) FROM testTable GROUP BY VALUEIN(column7, 363, 469, 246, 100000)";
     brokerResponse = getBrokerResponseForSqlQuery(query);
     DataSchema expectedDataSchema =
-        new DataSchema(new String[]{"valuein(column7,'363','469','246','100000')", "countmv(column6)"}, new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.LONG});
-    QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L,
-        Lists.newArrayList(new Object[]{469, (long)33576}, new Object[]{246, (long)24300}, new Object[]{363, (long)35436}), 3, expectedDataSchema);
-
-    query = "SELECT VALUEIN(column7, 363, 469, 246, 100000), COUNTMV(column6) FROM testTable GROUP BY VALUEIN(column7, 363, 469, 246, 100000) ORDER BY COUNTMV(column6)";
+        new DataSchema(new String[]{"valuein(column7,'363','469','246','100000')", "countmv(column6)"},
+            new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.LONG});
+    QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, Lists
+        .newArrayList(new Object[]{469, (long) 33576}, new Object[]{246, (long) 24300},
+            new Object[]{363, (long) 35436}), 3, expectedDataSchema);
+
+    query =
+        "SELECT VALUEIN(column7, 363, 469, 246, 100000), COUNTMV(column6) FROM testTable GROUP BY VALUEIN(column7, 363, 469, 246, 100000) ORDER BY COUNTMV(column6)";
     brokerResponse = getBrokerResponseForSqlQuery(query);
-    QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L,
-        Lists.newArrayList(new Object[]{246, (long)24300}, new Object[]{469, (long)33576}, new Object[]{363, (long)35436}), 3, expectedDataSchema);
+    QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, Lists
+        .newArrayList(new Object[]{246, (long) 24300}, new Object[]{469, (long) 33576},
+            new Object[]{363, (long) 35436}), 3, expectedDataSchema);
 
-    query = "SELECT VALUEIN(column7, 363, 469, 246, 100000), COUNTMV(column6) FROM testTable GROUP BY VALUEIN(column7, 363, 469, 246, 100000) ORDER BY COUNTMV(column6) DESC";
+    query =
+        "SELECT VALUEIN(column7, 363, 469, 246, 100000), COUNTMV(column6) FROM testTable GROUP BY VALUEIN(column7, 363, 469, 246, 100000) ORDER BY COUNTMV(column6) DESC";
     brokerResponse = getBrokerResponseForSqlQuery(query);
-    QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L,
-        Lists.newArrayList(new Object[]{363, (long)35436}, new Object[]{469, (long)33576}, new Object[]{246, (long)24300}), 3, expectedDataSchema);
-
-    query = "SELECT VALUEIN(column7, 363, 469, 246, 100000) AS value_in_col, COUNTMV(column6) FROM testTable GROUP BY value_in_col ORDER BY COUNTMV(column6) DESC";
-    expectedDataSchema =
-        new DataSchema(new String[]{"value_in_col", "countmv(column6)"}, new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.LONG});
+    QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, Lists
+        .newArrayList(new Object[]{363, (long) 35436}, new Object[]{469, (long) 33576},
+            new Object[]{246, (long) 24300}), 3, expectedDataSchema);
+
+    query =
+        "SELECT VALUEIN(column7, 363, 469, 246, 100000) AS value_in_col, COUNTMV(column6) FROM testTable GROUP BY value_in_col ORDER BY COUNTMV(column6) DESC";
+    expectedDataSchema = new DataSchema(new String[]{"value_in_col", "countmv(column6)"},
+        new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.LONG});
     brokerResponse = getBrokerResponseForSqlQuery(query);
-    QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L,
-        Lists.newArrayList(new Object[]{363, (long)35436}, new Object[]{469, (long)33576}, new Object[]{246, (long)24300}), 3, expectedDataSchema);
+    QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, Lists
+        .newArrayList(new Object[]{363, (long) 35436}, new Object[]{469, (long) 33576},
+            new Object[]{246, (long) 24300}), 3, expectedDataSchema);
 
     query = "SELECT COUNTMV(column6) FROM testTable GROUP BY daysSinceEpoch";
     brokerResponse = getBrokerResponseForPqlQuery(query);
@@ -105,18 +113,18 @@ public class InterSegmentAggregationMultiValueQueriesTest extends BaseMultiValue
     groupKeys.add(new String[]{"1756015683"});
     aggregations = new ArrayList<>();
     aggregations.add(new String[]{"426752"});
-    QueriesTestUtils
-        .testInterSegmentAggregationGroupByResult(brokerResponse, 400000L, 0L, 800000L, 400000L,
-            groupKeys, aggregations);
+    QueriesTestUtils.testInterSegmentAggregationGroupByResult(brokerResponse, 400000L, 0L, 800000L, 400000L, groupKeys,
+        aggregations);
 
-    query = "SELECT daysSinceEpoch, COUNTMV(column6) FROM testTable GROUP BY daysSinceEpoch ORDER BY COUNTMV(column6) DESC";
-    expectedDataSchema =
-        new DataSchema(new String[]{"daysSinceEpoch", "countmv(column6)"}, new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.LONG});
+    query =
+        "SELECT daysSinceEpoch, COUNTMV(column6) FROM testTable GROUP BY daysSinceEpoch ORDER BY COUNTMV(column6) DESC";
+    expectedDataSchema = new DataSchema(new String[]{"daysSinceEpoch", "countmv(column6)"},
+        new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.LONG});
     brokerResponse = getBrokerResponseForSqlQuery(query);
     List<Object[]> result = new ArrayList<>();
-    result.add(new Object[]{1756015683, (long)426752});
-    QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L,
-        result, 1, expectedDataSchema);
+    result.add(new Object[]{1756015683, (long) 426752});
+    QueriesTestUtils
+        .testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, result, 1, expectedDataSchema);
 
     query = "SELECT COUNTMV(column6) FROM testTable GROUP BY timeconvert(daysSinceEpoch, 'DAYS', 'HOURS')";
     brokerResponse = getBrokerResponseForPqlQuery(query);
@@ -124,18 +132,18 @@ public class InterSegmentAggregationMultiValueQueriesTest extends BaseMultiValue
     groupKeys.add(new String[]{"42144376392"});
     aggregations = new ArrayList<>();
     aggregations.add(new String[]{"426752"});
-    QueriesTestUtils
-        .testInterSegmentAggregationGroupByResult(brokerResponse, 400000L, 0L, 800000L, 400000L,
-            groupKeys, aggregations);
+    QueriesTestUtils.testInterSegmentAggregationGroupByResult(brokerResponse, 400000L, 0L, 800000L, 400000L, groupKeys,
+        aggregations);
 
-    query = "SELECT timeconvert(daysSinceEpoch, 'DAYS', 'HOURS'), COUNTMV(column6) FROM testTable GROUP BY timeconvert(daysSinceEpoch, 'DAYS', 'HOURS') ORDER BY COUNTMV(column6) DESC";
-    expectedDataSchema =
-        new DataSchema(new String[]{"timeconvert(daysSinceEpoch,'DAYS','HOURS')", "countmv(column6)"}, new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.LONG, DataSchema.ColumnDataType.LONG});
+    query =
+        "SELECT timeconvert(daysSinceEpoch, 'DAYS', 'HOURS'), COUNTMV(column6) FROM testTable GROUP BY timeconvert(daysSinceEpoch, 'DAYS', 'HOURS') ORDER BY COUNTMV(column6) DESC";
+    expectedDataSchema = new DataSchema(new String[]{"timeconvert(daysSinceEpoch,'DAYS','HOURS')", "countmv(column6)"},
+        new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.LONG, DataSchema.ColumnDataType.LONG});
     brokerResponse = getBrokerResponseForSqlQuery(query);
     result = new ArrayList<>();
-    result.add(new Object[]{42144376392L, (long)426752});
-    QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L,
-        result, 1, expectedDataSchema);
+    result.add(new Object[]{42144376392L, (long) 426752});
+    QueriesTestUtils
+        .testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, result, 1, expectedDataSchema);
   }
 
   @Test
@@ -340,23 +348,27 @@ public class InterSegmentAggregationMultiValueQueriesTest extends BaseMultiValue
 
   @Test
   public void testPercentile50MV() {
-    String query = "SELECT PERCENTILE50MV(column6) FROM testTable";
-
-    BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
-    QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 400000L, 0L, 400000L, 400000L,
-        new String[]{"2147483647.00000"});
-
-    brokerResponse = getBrokerResponseForPqlQueryWithFilter(query);
-    QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 62480L, 1101664L, 62480L, 400000L,
-        new String[]{"2147483647.00000"});
-
-    brokerResponse = getBrokerResponseForPqlQuery(query + SV_GROUP_BY);
-    QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 400000L, 0L, 800000L, 400000L,
-        new String[]{"2147483647.00000"});
-
-    brokerResponse = getBrokerResponseForPqlQuery(query + MV_GROUP_BY);
-    QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 400000L, 0L, 800000L, 400000L,
-        new String[]{"2147483647.00000"});
+    List<String> queries = Arrays
+        .asList("SELECT PERCENTILE50MV(column6) FROM testTable", "SELECT PERCENTILEMV(column6, 50) FROM testTable",
+            "SELECT PERCENTILEMV(column6, '50') FROM testTable", "SELECT PERCENTILEMV(column6, \"50\") FROM testTable");
+
+    for (String query : queries) {
+      BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
+      QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 400000L, 0L, 400000L, 400000L,
+          new String[]{"2147483647.00000"});
+
+      brokerResponse = getBrokerResponseForPqlQueryWithFilter(query);
+      QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 62480L, 1101664L, 62480L, 400000L,
+          new String[]{"2147483647.00000"});
+
+      brokerResponse = getBrokerResponseForPqlQuery(query + SV_GROUP_BY);
+      QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 400000L, 0L, 800000L, 400000L,
+          new String[]{"2147483647.00000"});
+
+      brokerResponse = getBrokerResponseForPqlQuery(query + MV_GROUP_BY);
+      QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 400000L, 0L, 800000L, 400000L,
+          new String[]{"2147483647.00000"});
+    }
   }
 
   @Test
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationSingleValueQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationSingleValueQueriesTest.java
index 8b18adb..70654d4 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationSingleValueQueriesTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationSingleValueQueriesTest.java
@@ -19,6 +19,8 @@
 package org.apache.pinot.queries;
 
 import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
 import java.util.function.Function;
 import org.apache.pinot.common.response.broker.BrokerResponseNative;
 import org.apache.pinot.common.response.broker.SelectionResults;
@@ -238,23 +240,28 @@ public class InterSegmentAggregationSingleValueQueriesTest extends BaseSingleVal
 
   @Test
   public void testPercentile50() {
-    String query = "SELECT PERCENTILE50(column1), PERCENTILE50(column3) FROM testTable";
-
-    BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
-    QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 120000L, 0L, 240000L, 120000L,
-        new String[]{"1107310944.00000", "1080136306.00000"});
-
-    brokerResponse = getBrokerResponseForPqlQueryWithFilter(query);
-    QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 24516L, 336536L, 49032L, 120000L,
-        new String[]{"1139674505.00000", "505053732.00000"});
-
-    brokerResponse = getBrokerResponseForPqlQuery(query + GROUP_BY);
-    QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 120000L, 0L, 360000L, 120000L,
-        new String[]{"2146791843.00000", "2141451242.00000"});
-
-    brokerResponse = getBrokerResponseForPqlQueryWithFilter(query + GROUP_BY);
-    QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 24516L, 336536L, 73548L, 120000L,
-        new String[]{"2142595699.00000", "999309554.00000"});
+    List<String> queries = Arrays.asList("SELECT PERCENTILE50(column1), PERCENTILE50(column3) FROM testTable",
+        "SELECT PERCENTILE(column1, 50), PERCENTILE(column3, 50) FROM testTable",
+        "SELECT PERCENTILE(column1, '50'), PERCENTILE(column3, '50') FROM testTable",
+        "SELECT PERCENTILE(column1, \"50\"), PERCENTILE(column3, \"50\") FROM testTable");
+
+    for (String query : queries) {
+      BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
+      QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 120000L, 0L, 240000L, 120000L,
+          new String[]{"1107310944.00000", "1080136306.00000"});
+
+      brokerResponse = getBrokerResponseForPqlQueryWithFilter(query);
+      QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 24516L, 336536L, 49032L, 120000L,
+          new String[]{"1139674505.00000", "505053732.00000"});
+
+      brokerResponse = getBrokerResponseForPqlQuery(query + GROUP_BY);
+      QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 120000L, 0L, 360000L, 120000L,
+          new String[]{"2146791843.00000", "2141451242.00000"});
+
+      brokerResponse = getBrokerResponseForPqlQueryWithFilter(query + GROUP_BY);
+      QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 24516L, 336536L, 73548L, 120000L,
+          new String[]{"2142595699.00000", "999309554.00000"});
+    }
   }
 
   @Test
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentResultTableMultiValueQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentResultTableMultiValueQueriesTest.java
index dcf5d7b..3026b0f 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentResultTableMultiValueQueriesTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentResultTableMultiValueQueriesTest.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.queries;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -471,48 +472,52 @@ public class InterSegmentResultTableMultiValueQueriesTest extends BaseMultiValue
 
   @Test
   public void testPercentile50MV() {
+    List<String> queries = Arrays
+        .asList("SELECT PERCENTILE50MV(column6) FROM testTable", "SELECT PERCENTILEMV(column6, 50) FROM testTable",
+            "SELECT PERCENTILEMV(column6, '50') FROM testTable", "SELECT PERCENTILEMV(column6, \"50\") FROM testTable");
+
     DataSchema dataSchema;
     List<Object[]> rows;
     int expectedResultsSize;
-    String query = "SELECT PERCENTILE50MV(column6) FROM testTable";
     Map<String, String> queryOptions = new HashMap<>(2);
     queryOptions.put(CommonConstants.Broker.Request.QueryOptionKey.RESPONSE_FORMAT, CommonConstants.Broker.Request.SQL);
-
-    BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query, queryOptions);
-    dataSchema = new DataSchema(new String[]{"percentile50mv(column6)"},
-        new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.DOUBLE});
-    rows = new ArrayList<>();
-    rows.add(new Object[]{2147483647.0});
-    expectedResultsSize = 1;
-    QueriesTestUtils
-        .testInterSegmentResultTable(brokerResponse, 400000L, 0L, 400000L, 400000L, rows, expectedResultsSize,
-            dataSchema);
-
-    brokerResponse = getBrokerResponseForPqlQuery(query + getFilter(), queryOptions);
-    rows = new ArrayList<>();
-    rows.add(new Object[]{2147483647.0});
-    QueriesTestUtils
-        .testInterSegmentResultTable(brokerResponse, 62480L, 1101664L, 62480L, 400000L, rows, expectedResultsSize,
-            dataSchema);
-
-    brokerResponse = getBrokerResponseForPqlQuery(query + SV_GROUP_BY, queryOptions);
-    dataSchema = new DataSchema(new String[]{"column8", "percentile50mv(column6)"},
-        new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.DOUBLE});
-    rows = new ArrayList<>();
-    rows.add(new Object[]{"169878844", 2147483647.0});
-    expectedResultsSize = 10;
-    QueriesTestUtils
-        .testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, rows, expectedResultsSize,
-            dataSchema);
-
-    brokerResponse = getBrokerResponseForPqlQuery(query + MV_GROUP_BY, queryOptions);
-    dataSchema = new DataSchema(new String[]{"column7", "percentile50mv(column6)"},
-        new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.DOUBLE});
-    rows = new ArrayList<>();
-    rows.add(new Object[]{"372", 2147483647.0});
-    QueriesTestUtils
-        .testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, rows, expectedResultsSize,
-            dataSchema);
+    for (String query : queries) {
+      BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query, queryOptions);
+      dataSchema = new DataSchema(new String[]{"percentile50mv(column6)"},
+          new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.DOUBLE});
+      rows = new ArrayList<>();
+      rows.add(new Object[]{2147483647.0});
+      expectedResultsSize = 1;
+      QueriesTestUtils
+          .testInterSegmentResultTable(brokerResponse, 400000L, 0L, 400000L, 400000L, rows, expectedResultsSize,
+              dataSchema);
+
+      brokerResponse = getBrokerResponseForPqlQuery(query + getFilter(), queryOptions);
+      rows = new ArrayList<>();
+      rows.add(new Object[]{2147483647.0});
+      QueriesTestUtils
+          .testInterSegmentResultTable(brokerResponse, 62480L, 1101664L, 62480L, 400000L, rows, expectedResultsSize,
+              dataSchema);
+
+      brokerResponse = getBrokerResponseForPqlQuery(query + SV_GROUP_BY, queryOptions);
+      dataSchema = new DataSchema(new String[]{"column8", "percentile50mv(column6)"},
+          new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.DOUBLE});
+      rows = new ArrayList<>();
+      rows.add(new Object[]{"169878844", 2147483647.0});
+      expectedResultsSize = 10;
+      QueriesTestUtils
+          .testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, rows, expectedResultsSize,
+              dataSchema);
+
+      brokerResponse = getBrokerResponseForPqlQuery(query + MV_GROUP_BY, queryOptions);
+      dataSchema = new DataSchema(new String[]{"column7", "percentile50mv(column6)"},
+          new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.DOUBLE});
+      rows = new ArrayList<>();
+      rows.add(new Object[]{"372", 2147483647.0});
+      QueriesTestUtils
+          .testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, rows, expectedResultsSize,
+              dataSchema);
+    }
   }
 
   @Test
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentResultTableSingleValueQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentResultTableSingleValueQueriesTest.java
index fe6c6df..efa941c 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentResultTableSingleValueQueriesTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentResultTableSingleValueQueriesTest.java
@@ -20,6 +20,7 @@ package org.apache.pinot.queries;
 
 import com.google.common.collect.Lists;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -482,47 +483,52 @@ public class InterSegmentResultTableSingleValueQueriesTest extends BaseSingleVal
 
   @Test
   public void testPercentile50() {
+    List<String> queries = Arrays.asList("SELECT PERCENTILE50(column1), PERCENTILE50(column3) FROM testTable",
+        "SELECT PERCENTILE(column1, 50), PERCENTILE(column3, 50) FROM testTable",
+        "SELECT PERCENTILE(column1, '50'), PERCENTILE(column3, '50') FROM testTable",
+        "SELECT PERCENTILE(column1, \"50\"), PERCENTILE(column3, \"50\") FROM testTable");
+
     DataSchema dataSchema;
     List<Object[]> rows;
     int expectedResultsSize;
-    String query = "SELECT PERCENTILE50(column1), PERCENTILE50(column3) FROM testTable";
     Map<String, String> queryOptions = new HashMap<>(2);
     queryOptions.put(QueryOptionKey.RESPONSE_FORMAT, Request.SQL);
-
-    BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query, queryOptions);
-    dataSchema = new DataSchema(new String[]{"percentile50(column1)", "percentile50(column3)"},
-        new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE});
-    rows = new ArrayList<>();
-    rows.add(new Object[]{1107310944.0, 1080136306.0});
-    expectedResultsSize = 1;
-    QueriesTestUtils
-        .testInterSegmentResultTable(brokerResponse, 120000L, 0L, 240000L, 120000L, rows, expectedResultsSize,
-            dataSchema);
-
-    brokerResponse = getBrokerResponseForPqlQuery(query + getFilter(), queryOptions);
-    rows = new ArrayList<>();
-    rows.add(new Object[]{1139674505.0, 505053732.0});
-    QueriesTestUtils
-        .testInterSegmentResultTable(brokerResponse, 24516L, 336536L, 49032L, 120000L, rows, expectedResultsSize,
-            dataSchema);
-
-    query = "SELECT PERCENTILE50(column3) FROM testTable";
-    brokerResponse = getBrokerResponseForPqlQuery(query + GROUP_BY, queryOptions);
-    dataSchema = new DataSchema(new String[]{"column9", "percentile50(column3)"},
-        new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.DOUBLE});
-    rows = new ArrayList<>();
-    rows.add(new Object[]{"1642909995", 2141451242.0});
-    expectedResultsSize = 10;
-    QueriesTestUtils
-        .testInterSegmentResultTable(brokerResponse, 120000L, 0L, 240000L, 120000L, rows, expectedResultsSize,
-            dataSchema);
-
-    brokerResponse = getBrokerResponseForPqlQuery(query + GROUP_BY + getFilter(), queryOptions);
-    rows = new ArrayList<>();
-    rows.add(new Object[]{"438926263", 999309554.0});
-    QueriesTestUtils
-        .testInterSegmentResultTable(brokerResponse, 24516L, 336536L, 49032L, 120000L, rows, expectedResultsSize,
-            dataSchema);
+    for (String query : queries) {
+      BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query, queryOptions);
+      dataSchema = new DataSchema(new String[]{"percentile50(column1)", "percentile50(column3)"},
+          new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE});
+      rows = new ArrayList<>();
+      rows.add(new Object[]{1107310944.0, 1080136306.0});
+      expectedResultsSize = 1;
+      QueriesTestUtils
+          .testInterSegmentResultTable(brokerResponse, 120000L, 0L, 240000L, 120000L, rows, expectedResultsSize,
+              dataSchema);
+
+      brokerResponse = getBrokerResponseForPqlQuery(query + getFilter(), queryOptions);
+      rows = new ArrayList<>();
+      rows.add(new Object[]{1139674505.0, 505053732.0});
+      QueriesTestUtils
+          .testInterSegmentResultTable(brokerResponse, 24516L, 336536L, 49032L, 120000L, rows, expectedResultsSize,
+              dataSchema);
+
+      query = "SELECT PERCENTILE50(column3) FROM testTable";
+      brokerResponse = getBrokerResponseForPqlQuery(query + GROUP_BY, queryOptions);
+      dataSchema = new DataSchema(new String[]{"column9", "percentile50(column3)"},
+          new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.DOUBLE});
+      rows = new ArrayList<>();
+      rows.add(new Object[]{"1642909995", 2141451242.0});
+      expectedResultsSize = 10;
+      QueriesTestUtils
+          .testInterSegmentResultTable(brokerResponse, 120000L, 0L, 240000L, 120000L, rows, expectedResultsSize,
+              dataSchema);
+
+      brokerResponse = getBrokerResponseForPqlQuery(query + GROUP_BY + getFilter(), queryOptions);
+      rows = new ArrayList<>();
+      rows.add(new Object[]{"438926263", 999309554.0});
+      QueriesTestUtils
+          .testInterSegmentResultTable(brokerResponse, 24516L, 336536L, 49032L, 120000L, rows, expectedResultsSize,
+              dataSchema);
+    }
   }
 
   @Test
diff --git a/pinot-core/src/test/java/org/apache/pinot/query/aggregation/DefaultAggregationExecutorTest.java b/pinot-core/src/test/java/org/apache/pinot/query/aggregation/DefaultAggregationExecutorTest.java
index ee13450..1371f9a 100644
--- a/pinot-core/src/test/java/org/apache/pinot/query/aggregation/DefaultAggregationExecutorTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/query/aggregation/DefaultAggregationExecutorTest.java
@@ -20,7 +20,6 @@ package org.apache.pinot.query.aggregation;
 
 import java.io.File;
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -28,7 +27,7 @@ import java.util.Map;
 import java.util.Random;
 import java.util.Set;
 import org.apache.commons.io.FileUtils;
-import org.apache.pinot.common.request.AggregationInfo;
+import org.apache.pinot.common.request.BrokerRequest;
 import org.apache.pinot.common.request.transform.TransformExpressionTree;
 import org.apache.pinot.common.segment.ReadMode;
 import org.apache.pinot.core.common.DataSource;
@@ -43,10 +42,11 @@ import org.apache.pinot.core.operator.filter.MatchAllFilterOperator;
 import org.apache.pinot.core.operator.transform.TransformOperator;
 import org.apache.pinot.core.plan.DocIdSetPlanNode;
 import org.apache.pinot.core.query.aggregation.AggregationExecutor;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
 import org.apache.pinot.core.query.aggregation.DefaultAggregationExecutor;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 import org.apache.pinot.core.segment.creator.impl.SegmentIndexCreationDriverImpl;
+import org.apache.pinot.pql.parsers.Pql2Compiler;
 import org.apache.pinot.spi.config.table.TableType;
 import org.apache.pinot.spi.data.FieldSpec;
 import org.apache.pinot.spi.data.MetricFieldSpec;
@@ -88,8 +88,8 @@ public class DefaultAggregationExecutorTest {
 
   public static IndexSegment _indexSegment;
   private Random _random;
-  private List<AggregationInfo> _aggregationInfoList;
   private String[] _columns;
+  private BrokerRequest _brokerRequest;
   private double[][] _inputData;
 
   /**
@@ -110,15 +110,15 @@ public class DefaultAggregationExecutorTest {
     _columns = new String[numColumns];
     setupSegment();
 
-    _aggregationInfoList = new ArrayList<>();
-
-    for (int i = 0; i < _columns.length; i++) {
-      AggregationInfo aggregationInfo = new AggregationInfo();
-      aggregationInfo.setAggregationType(AGGREGATION_FUNCTIONS[i]);
-
-      aggregationInfo.setExpressions(Collections.singletonList(_columns[i]));
-      _aggregationInfoList.add(aggregationInfo);
+    StringBuilder queryBuilder = new StringBuilder("SELECT");
+    for (int i = 0; i < numColumns; i++) {
+      queryBuilder.append(String.format(" %s(%s)", AGGREGATION_FUNCTIONS[i], _columns[i]));
+      if (i != numColumns - 1) {
+        queryBuilder.append(',');
+      }
     }
+    queryBuilder.append(" FROM testTable");
+    _brokerRequest = new Pql2Compiler().compileToBrokerRequest(queryBuilder.toString());
   }
 
   /**
@@ -139,13 +139,8 @@ public class DefaultAggregationExecutorTest {
     ProjectionOperator projectionOperator = new ProjectionOperator(dataSourceMap, docIdSetOperator);
     TransformOperator transformOperator = new TransformOperator(projectionOperator, expressionTrees);
     TransformBlock transformBlock = transformOperator.nextBlock();
-    int numAggFuncs = _aggregationInfoList.size();
-    AggregationFunctionContext[] aggrFuncContextArray = new AggregationFunctionContext[numAggFuncs];
-    for (int i = 0; i < numAggFuncs; i++) {
-      AggregationInfo aggregationInfo = _aggregationInfoList.get(i);
-      aggrFuncContextArray[i] = AggregationFunctionUtils.getAggregationFunctionContext(aggregationInfo);
-    }
-    AggregationExecutor aggregationExecutor = new DefaultAggregationExecutor(aggrFuncContextArray);
+    AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(_brokerRequest);
+    AggregationExecutor aggregationExecutor = new DefaultAggregationExecutor(aggregationFunctions);
     aggregationExecutor.aggregate(transformBlock);
     List<Object> result = aggregationExecutor.getResult();
     for (int i = 0; i < result.size(); i++) {
diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkCombineGroupBy.java b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkCombineGroupBy.java
index 954db47..05b2e1b 100644
--- a/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkCombineGroupBy.java
+++ b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkCombineGroupBy.java
@@ -19,8 +19,9 @@
 package org.apache.pinot.perf;
 
 import com.google.common.base.Joiner;
-import com.google.common.collect.Lists;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -37,7 +38,6 @@ import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.commons.lang3.RandomStringUtils;
-import org.apache.pinot.common.request.AggregationInfo;
 import org.apache.pinot.common.request.GroupBy;
 import org.apache.pinot.common.request.SelectionSort;
 import org.apache.pinot.common.utils.DataSchema;
@@ -45,7 +45,8 @@ import org.apache.pinot.core.data.table.ConcurrentIndexedTable;
 import org.apache.pinot.core.data.table.IndexedTable;
 import org.apache.pinot.core.data.table.Record;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory;
+import org.apache.pinot.core.query.aggregation.function.MaxAggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.SumAggregationFunction;
 import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByTrimmingService;
 import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
 import org.apache.pinot.core.query.utils.Pair;
@@ -77,11 +78,9 @@ public class BenchmarkCombineGroupBy {
   private Random _random = new Random();
 
   private DataSchema _dataSchema;
-  private List<AggregationInfo> _aggregationInfos;
-  private GroupBy _groupBy;
   private AggregationFunction[] _aggregationFunctions;
+  private GroupBy _groupBy;
   private List<SelectionSort> _orderBy;
-  private int _numAggregationFunctions;
 
   private List<String> _d1;
   private List<Integer> _d2;
@@ -105,36 +104,19 @@ public class BenchmarkCombineGroupBy {
     }
 
     _dataSchema = new DataSchema(new String[]{"d1", "d2", "sum(m1)", "max(m2)"},
-        new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
-            DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE});
-
-    AggregationInfo agg1 = new AggregationInfo();
-    List<String> args1 = new ArrayList<>(1);
-    args1.add("m1");
-    agg1.setExpressions(args1);
-    agg1.setAggregationType("sum");
-
-    AggregationInfo agg2 = new AggregationInfo();
-    List<String> args2 = new ArrayList<>(1);
-    args2.add("m2");
-    agg2.setExpressions(args2);
-    agg2.setAggregationType("max");
-    _aggregationInfos = Lists.newArrayList(agg1, agg2);
-
-    _numAggregationFunctions = 2;
-    _aggregationFunctions = new AggregationFunction[_numAggregationFunctions];
-    for (int i = 0; i < _numAggregationFunctions; i++) {
-      _aggregationFunctions[i] = AggregationFunctionFactory.getAggregationFunction(_aggregationInfos.get(i), null);
-    }
+        new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE});
+
+    _aggregationFunctions =
+        new AggregationFunction[]{new SumAggregationFunction("m1"), new MaxAggregationFunction("m2")};
 
     _groupBy = new GroupBy();
     _groupBy.setTopN(TOP_N);
-    _groupBy.setExpressions(Lists.newArrayList("d1", "d2"));
+    _groupBy.setExpressions(Arrays.asList("d1", "d2"));
 
-    SelectionSort orderBy = new SelectionSort();
-    orderBy.setColumn("sum(m1)");
-    orderBy.setIsAsc(true);
-    _orderBy = Lists.newArrayList(orderBy);
+    SelectionSort selectionSort = new SelectionSort();
+    selectionSort.setColumn("sum(m1)");
+    selectionSort.setIsAsc(true);
+    _orderBy = Collections.singletonList(selectionSort);
 
     _executorService = Executors.newFixedThreadPool(10);
   }
@@ -161,13 +143,14 @@ public class BenchmarkCombineGroupBy {
   @Benchmark
   @BenchmarkMode(Mode.AverageTime)
   @OutputTimeUnit(TimeUnit.MICROSECONDS)
-  public void concurrentIndexedTableForCombineGroupBy() throws InterruptedException, ExecutionException, TimeoutException {
+  public void concurrentIndexedTableForCombineGroupBy()
+      throws InterruptedException, ExecutionException, TimeoutException {
 
     int capacity = GroupByUtils.getTableCapacity(_groupBy, _orderBy);
 
     // make 1 concurrent table
     IndexedTable concurrentIndexedTable =
-        new ConcurrentIndexedTable(_dataSchema, _aggregationInfos, _orderBy, capacity);
+        new ConcurrentIndexedTable(_dataSchema, _aggregationFunctions, _orderBy, capacity);
 
     List<Callable<Void>> innerSegmentCallables = new ArrayList<>(NUM_SEGMENTS);
 
@@ -193,11 +176,11 @@ public class BenchmarkCombineGroupBy {
     concurrentIndexedTable.finish(false);
   }
 
-
   @Benchmark
   @BenchmarkMode(Mode.AverageTime)
   @OutputTimeUnit(TimeUnit.MICROSECONDS)
-  public void originalCombineGroupBy() throws InterruptedException, TimeoutException, ExecutionException {
+  public void originalCombineGroupBy()
+      throws InterruptedException, TimeoutException, ExecutionException {
 
     AtomicInteger numGroups = new AtomicInteger();
     int _interSegmentNumGroupsLimit = 200_000;
@@ -213,15 +196,14 @@ public class BenchmarkCombineGroupBy {
           final Object[] value = newRecordOriginal.getSecond();
 
           resultsMap.compute(stringKey, (k, v) -> {
+            int numAggregationFunctions = _aggregationFunctions.length;
             if (v == null) {
               if (numGroups.getAndIncrement() < _interSegmentNumGroupsLimit) {
-                v = new Object[_numAggregationFunctions];
-                for (int j = 0; j < _numAggregationFunctions; j++) {
-                  v[j] = value[j];
-                }
+                v = new Object[numAggregationFunctions];
+                System.arraycopy(value, 0, v, 0, numAggregationFunctions);
               }
             } else {
-              for (int j = 0; j < _numAggregationFunctions; j++) {
+              for (int j = 0; j < numAggregationFunctions; j++) {
                 v[j] = _aggregationFunctions[j].merge(v[j], value[j]);
               }
             }
@@ -243,13 +225,11 @@ public class BenchmarkCombineGroupBy {
     List<Map<String, Object>> trimmedResults = aggregationGroupByTrimmingService.trimIntermediateResultsMap(resultsMap);
   }
 
-  public static void main(String[] args) throws Exception {
-    ChainedOptionsBuilder opt = new OptionsBuilder().include(BenchmarkCombineGroupBy.class.getSimpleName())
-        .warmupTime(TimeValue.seconds(10))
-        .warmupIterations(1)
-        .measurementTime(TimeValue.seconds(30))
-        .measurementIterations(3)
-        .forks(1);
+  public static void main(String[] args)
+      throws Exception {
+    ChainedOptionsBuilder opt =
+        new OptionsBuilder().include(BenchmarkCombineGroupBy.class.getSimpleName()).warmupTime(TimeValue.seconds(10))
+            .warmupIterations(1).measurementTime(TimeValue.seconds(30)).measurementIterations(3).forks(1);
 
     new Runner(opt.build()).run();
   }
diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkIndexedTable.java b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkIndexedTable.java
index 7312b6b..0b17edb 100644
--- a/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkIndexedTable.java
+++ b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkIndexedTable.java
@@ -18,8 +18,8 @@
  */
 package org.apache.pinot.perf;
 
-import com.google.common.collect.Lists;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Random;
@@ -33,13 +33,15 @@ import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import org.apache.commons.lang3.RandomStringUtils;
-import org.apache.pinot.common.request.AggregationInfo;
 import org.apache.pinot.common.request.SelectionSort;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.data.table.ConcurrentIndexedTable;
 import org.apache.pinot.core.data.table.IndexedTable;
 import org.apache.pinot.core.data.table.Record;
 import org.apache.pinot.core.data.table.SimpleIndexedTable;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.MaxAggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.SumAggregationFunction;
 import org.apache.pinot.core.util.trace.TraceRunnable;
 import org.openjdk.jmh.annotations.Benchmark;
 import org.openjdk.jmh.annotations.BenchmarkMode;
@@ -63,7 +65,7 @@ public class BenchmarkIndexedTable {
   private Random _random = new Random();
 
   private DataSchema _dataSchema;
-  private List<AggregationInfo> _aggregationInfos;
+  private AggregationFunction[] _aggregationFunctions;
   private List<SelectionSort> _orderBy;
 
   private List<String> _d1;
@@ -71,7 +73,6 @@ public class BenchmarkIndexedTable {
 
   private ExecutorService _executorService;
 
-
   @Setup
   public void setup() {
     // create data
@@ -90,26 +91,15 @@ public class BenchmarkIndexedTable {
     }
 
     _dataSchema = new DataSchema(new String[]{"d1", "d2", "sum(m1)", "max(m2)"},
-        new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
-            DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE});
-
-    AggregationInfo agg1 = new AggregationInfo();
-    List<String> args1 = new ArrayList<>();
-    args1.add("m1");
-    agg1.setExpressions(args1);
-    agg1.setAggregationType("sum");
-
-    AggregationInfo agg2 = new AggregationInfo();
-    List<String> args2 = new ArrayList<>();
-    args2.add("m2");
-    agg2.setExpressions(args2);
-    agg2.setAggregationType("max");
-    _aggregationInfos = Lists.newArrayList(agg1, agg2);
-
-    SelectionSort orderBy = new SelectionSort();
-    orderBy.setColumn("sum(m1)");
-    orderBy.setIsAsc(true);
-    _orderBy = Lists.newArrayList(orderBy);
+        new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE});
+
+    _aggregationFunctions =
+        new AggregationFunction[]{new SumAggregationFunction("m1"), new MaxAggregationFunction("m2")};
+
+    SelectionSort selectionSort = new SelectionSort();
+    selectionSort.setColumn("sum(m1)");
+    selectionSort.setIsAsc(true);
+    _orderBy = Collections.singletonList(selectionSort);
 
     _executorService = Executors.newFixedThreadPool(10);
   }
@@ -129,13 +119,14 @@ public class BenchmarkIndexedTable {
   @Benchmark
   @BenchmarkMode(Mode.AverageTime)
   @OutputTimeUnit(TimeUnit.MICROSECONDS)
-  public void concurrentIndexedTable() throws InterruptedException, ExecutionException, TimeoutException {
+  public void concurrentIndexedTable()
+      throws InterruptedException, ExecutionException, TimeoutException {
 
     int numSegments = 10;
 
     // make 1 concurrent table
     IndexedTable concurrentIndexedTable =
-        new ConcurrentIndexedTable(_dataSchema, _aggregationInfos, _orderBy, CAPACITY);
+        new ConcurrentIndexedTable(_dataSchema, _aggregationFunctions, _orderBy, CAPACITY);
 
     // 10 parallel threads putting 10k records into the table
 
@@ -172,11 +163,11 @@ public class BenchmarkIndexedTable {
     }
   }
 
-
   @Benchmark
   @BenchmarkMode(Mode.AverageTime)
   @OutputTimeUnit(TimeUnit.MICROSECONDS)
-  public void simpleIndexedTable() throws InterruptedException, TimeoutException, ExecutionException {
+  public void simpleIndexedTable()
+      throws InterruptedException, TimeoutException, ExecutionException {
 
     int numSegments = 10;
 
@@ -186,7 +177,7 @@ public class BenchmarkIndexedTable {
     for (int i = 0; i < numSegments; i++) {
 
       // make 10 indexed tables
-      IndexedTable simpleIndexedTable = new SimpleIndexedTable(_dataSchema, _aggregationInfos, _orderBy, CAPACITY);
+      IndexedTable simpleIndexedTable = new SimpleIndexedTable(_dataSchema, _aggregationFunctions, _orderBy, CAPACITY);
       simpleIndexedTables.add(simpleIndexedTable);
 
       // put 10k records in each indexed table, in parallel
@@ -217,13 +208,11 @@ public class BenchmarkIndexedTable {
     mergedTable.finish(false);
   }
 
-  public static void main(String[] args) throws Exception {
-    ChainedOptionsBuilder opt = new OptionsBuilder().include(BenchmarkIndexedTable.class.getSimpleName())
-        .warmupTime(TimeValue.seconds(10))
-        .warmupIterations(1)
-        .measurementTime(TimeValue.seconds(30))
-        .measurementIterations(3)
-        .forks(1);
+  public static void main(String[] args)
+      throws Exception {
+    ChainedOptionsBuilder opt =
+        new OptionsBuilder().include(BenchmarkIndexedTable.class.getSimpleName()).warmupTime(TimeValue.seconds(10))
+            .warmupIterations(1).measurementTime(TimeValue.seconds(30)).measurementIterations(3).forks(1);
 
     new Runner(opt.build()).run();
   }
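The PERCENTILE tests in this commit expect `PERCENTILE(column1, 50)`, `PERCENTILE(column1, '50')` and `PERCENTILE(column1, "50")` to return the same results as `PERCENTILE50(column1)`. The Java sketch below shows one way such a percentile argument could be normalized. It assumes the argument reaches the function as raw literal text. `PercentileArgs`, `parsePercentile` and `parseFromFunctionName` are hypothetical names, not Pinot's actual parsing code.

```java
import java.util.Locale;

/** Hypothetical helper: normalizes the percentile argument of PERCENTILE-style functions. */
final class PercentileArgs {
  private PercentileArgs() {
  }

  /** Parses "50", "'50'" or "\"50\"" to 50.0; rejects NaN and values outside [0, 100]. */
  static double parsePercentile(String literal) {
    String s = literal.trim();
    // Strip one pair of matching single or double quotes, if present.
    if (s.length() >= 2) {
      char first = s.charAt(0);
      char last = s.charAt(s.length() - 1);
      if ((first == '\'' || first == '"') && first == last) {
        s = s.substring(1, s.length() - 1).trim();
      }
    }
    double percentile = Double.parseDouble(s);
    // Written as a negated range check so that NaN is rejected too.
    if (!(percentile >= 0 && percentile <= 100)) {
      throw new IllegalArgumentException("Percentile must be in [0, 100], got: " + literal);
    }
    return percentile;
  }

  /**
   * Extracts the percentile from a legacy name such as "PERCENTILE50" or "PERCENTILE50MV".
   * The prefix and suffix are expected in upper case.
   */
  static double parseFromFunctionName(String functionName, String prefix, String suffix) {
    String upper = functionName.toUpperCase(Locale.ROOT);
    if (!upper.startsWith(prefix) || !upper.endsWith(suffix)
        || upper.length() <= prefix.length() + suffix.length()) {
      throw new IllegalArgumentException("Unexpected function name: " + functionName);
    }
    return parsePercentile(upper.substring(prefix.length(), upper.length() - suffix.length()));
  }
}
```

The three spellings used in the tests resolve to the same value because the quotes are stripped before the number is parsed.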

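`DefaultAggregationExecutorTest` now builds its query string with a `StringBuilder`, using an explicit index check to place commas between select items, and then compiles it with `Pql2Compiler`. The following is a minimal sketch of an equivalent builder that uses `Collectors.joining`, which inserts the separators itself. `AggregationQueries.buildAggregationQuery` is a hypothetical helper and is not part of the commit.

```java
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/** Hypothetical helper: builds "SELECT f0(c0), f1(c1), ... FROM table". */
final class AggregationQueries {
  private AggregationQueries() {
  }

  static String buildAggregationQuery(String[] functions, String[] columns, String table) {
    if (functions.length == 0 || functions.length != columns.length) {
      throw new IllegalArgumentException("Need one column per aggregation function");
    }
    // joining(delimiter, prefix, suffix) places ", " only between items.
    return IntStream.range(0, functions.length)
        .mapToObj(i -> String.format("%s(%s)", functions[i], columns[i]))
        .collect(Collectors.joining(", ", "SELECT ", " FROM " + table));
  }
}
```

For functions `{"sum", "max"}` and columns `{"m1", "m2"}`, the builder produces `SELECT sum(m1), max(m2) FROM testTable`, which is the same string the loop in the test builds.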

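`originalCombineGroupBy` in `BenchmarkCombineGroupBy` merges per-segment results into a shared map with `compute`. It admits new groups only while a counter is below `_interSegmentNumGroupsLimit`, and it copies the first value array with `System.arraycopy`. The self-contained sketch below shows that pattern. It assumes a `ConcurrentHashMap` and substitutes `BinaryOperator<Object>` mergers for `AggregationFunction.merge`, so `GroupMerger` and its methods are hypothetical.

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BinaryOperator;

/** Hypothetical sketch of a bounded, thread-safe group-by merge. */
final class GroupMerger {
  private final Map<String, Object[]> results = new ConcurrentHashMap<>();
  private final AtomicInteger numGroups = new AtomicInteger();
  private final int numGroupsLimit;
  // One merger per aggregation column, standing in for AggregationFunction.merge.
  private final List<BinaryOperator<Object>> mergers;

  GroupMerger(int numGroupsLimit, List<BinaryOperator<Object>> mergers) {
    this.numGroupsLimit = numGroupsLimit;
    this.mergers = mergers;
  }

  void merge(String groupKey, Object[] values) {
    int n = mergers.size();
    results.compute(groupKey, (k, v) -> {
      if (v == null) {
        // New group: admit it only while under the limit; returning null leaves it unmapped.
        if (numGroups.getAndIncrement() < numGroupsLimit) {
          v = new Object[n];
          System.arraycopy(values, 0, v, 0, n);
        }
      } else {
        // Existing group: merge each aggregation value in place.
        for (int j = 0; j < n; j++) {
          v[j] = mergers.get(j).apply(v[j], values[j]);
        }
      }
      return v;
    });
  }

  Map<String, Object[]> results() {
    return results;
  }
}
```

When the limit is reached, the remapping function returns null for a new key, so extra groups are dropped instead of being stored empty. Because the first array is copied rather than stored, later merges do not write into the caller's array.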