You are viewing a plain text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text version.
Posted to commits@pinot.apache.org by ja...@apache.org on 2023/11/18 21:18:24 UTC

(pinot) branch master updated: Adds support for leveraging StarTree index in conjunction with filtered aggregations (#11886)

This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new f7f82608e0 Adds support for leveraging StarTree index in conjunction with filtered aggregations (#11886)
f7f82608e0 is described below

commit f7f82608e098b26f2d241627d8a0a811cd8f5fe8
Author: Evan Galpin <eg...@users.noreply.github.com>
AuthorDate: Sat Nov 18 13:18:18 2023 -0800

    Adds support for leveraging StarTree index in conjunction with filtered aggregations (#11886)
---
 .../core/operator/query/AggregationOperator.java   |  10 +-
 .../query/FilteredAggregationOperator.java         |  27 ++--
 .../operator/query/FilteredGroupByOperator.java    |  42 +++---
 .../pinot/core/operator/query/GroupByOperator.java |  19 +--
 .../pinot/core/plan/AggregationPlanNode.java       |  54 ++------
 .../org/apache/pinot/core/plan/FilterPlanNode.java |  10 +-
 .../apache/pinot/core/plan/GroupByPlanNode.java    |  61 +--------
 .../function/AggregationFunctionUtils.java         | 152 ++++++++++++++++-----
 .../apache/pinot/core/startree/StarTreeUtils.java  |  44 +++++-
 .../startree/executor/StarTreeGroupByExecutor.java |  16 ++-
 .../tests/OfflineClusterIntegrationTest.java       |  23 ++++
 .../tests/StarTreeClusterIntegrationTest.java      |  52 ++++++-
 .../tests/startree/StarTreeQueryGenerator.java     |  17 ++-
 13 files changed, 335 insertions(+), 192 deletions(-)

diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java
index f810261049..0c89cec32e 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java
@@ -28,6 +28,7 @@ import org.apache.pinot.core.operator.blocks.results.AggregationResultsBlock;
 import org.apache.pinot.core.query.aggregation.AggregationExecutor;
 import org.apache.pinot.core.query.aggregation.DefaultAggregationExecutor;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils.AggregationInfo;
 import org.apache.pinot.core.query.request.context.QueryContext;
 import org.apache.pinot.core.startree.executor.StarTreeAggregationExecutor;
 
@@ -42,18 +43,17 @@ public class AggregationOperator extends BaseOperator<AggregationResultsBlock> {
   private final QueryContext _queryContext;
   private final AggregationFunction[] _aggregationFunctions;
   private final BaseProjectOperator<?> _projectOperator;
-  private final long _numTotalDocs;
   private final boolean _useStarTree;
+  private final long _numTotalDocs;
 
   private int _numDocsScanned = 0;
 
-  public AggregationOperator(QueryContext queryContext, BaseProjectOperator<?> projectOperator, long numTotalDocs,
-      boolean useStarTree) {
+  public AggregationOperator(QueryContext queryContext, AggregationInfo aggregationInfo, long numTotalDocs) {
     _queryContext = queryContext;
     _aggregationFunctions = queryContext.getAggregationFunctions();
-    _projectOperator = projectOperator;
+    _projectOperator = aggregationInfo.getProjectOperator();
+    _useStarTree = aggregationInfo.isUseStarTree();
     _numTotalDocs = numTotalDocs;
-    _useStarTree = useStarTree;
   }
 
   @Override
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java
index 9d13c06cd0..ed68c54ab1 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java
@@ -22,7 +22,6 @@ import java.util.Arrays;
 import java.util.IdentityHashMap;
 import java.util.List;
 import java.util.stream.Collectors;
-import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.core.common.Operator;
 import org.apache.pinot.core.operator.BaseOperator;
 import org.apache.pinot.core.operator.BaseProjectOperator;
@@ -32,7 +31,9 @@ import org.apache.pinot.core.operator.blocks.results.AggregationResultsBlock;
 import org.apache.pinot.core.query.aggregation.AggregationExecutor;
 import org.apache.pinot.core.query.aggregation.DefaultAggregationExecutor;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils.AggregationInfo;
 import org.apache.pinot.core.query.request.context.QueryContext;
+import org.apache.pinot.core.startree.executor.StarTreeAggregationExecutor;
 
 
 /**
@@ -47,18 +48,18 @@ public class FilteredAggregationOperator extends BaseOperator<AggregationResults
 
   private final QueryContext _queryContext;
   private final AggregationFunction[] _aggregationFunctions;
-  private final List<Pair<AggregationFunction[], BaseProjectOperator<?>>> _projectOperators;
+  private final List<AggregationInfo> _aggregationInfos;
   private final long _numTotalDocs;
 
   private long _numDocsScanned;
   private long _numEntriesScannedInFilter;
   private long _numEntriesScannedPostFilter;
 
-  public FilteredAggregationOperator(QueryContext queryContext,
-      List<Pair<AggregationFunction[], BaseProjectOperator<?>>> projectOperators, long numTotalDocs) {
+  public FilteredAggregationOperator(QueryContext queryContext, List<AggregationInfo> aggregationInfos,
+      long numTotalDocs) {
     _queryContext = queryContext;
     _aggregationFunctions = queryContext.getAggregationFunctions();
-    _projectOperators = projectOperators;
+    _aggregationInfos = aggregationInfos;
     _numTotalDocs = numTotalDocs;
   }
 
@@ -71,10 +72,16 @@ public class FilteredAggregationOperator extends BaseOperator<AggregationResults
       resultIndexMap.put(_aggregationFunctions[i], i);
     }
 
-    for (Pair<AggregationFunction[], BaseProjectOperator<?>> pair : _projectOperators) {
-      AggregationFunction[] aggregationFunctions = pair.getLeft();
-      AggregationExecutor aggregationExecutor = new DefaultAggregationExecutor(aggregationFunctions);
-      BaseProjectOperator<?> projectOperator = pair.getRight();
+    for (AggregationInfo aggregationInfo : _aggregationInfos) {
+      AggregationFunction[] aggregationFunctions = aggregationInfo.getFunctions();
+      BaseProjectOperator<?> projectOperator = aggregationInfo.getProjectOperator();
+      AggregationExecutor aggregationExecutor;
+      if (aggregationInfo.isUseStarTree()) {
+        aggregationExecutor = new StarTreeAggregationExecutor(aggregationFunctions);
+      } else {
+        aggregationExecutor = new DefaultAggregationExecutor(aggregationFunctions);
+      }
+
       ValueBlock valueBlock;
       int numDocsScanned = 0;
       while ((valueBlock = projectOperator.nextBlock()) != null) {
@@ -95,7 +102,7 @@ public class FilteredAggregationOperator extends BaseOperator<AggregationResults
 
   @Override
   public List<Operator> getChildOperators() {
-    return _projectOperators.stream().map(Pair::getRight).collect(Collectors.toList());
+    return _aggregationInfos.stream().map(AggregationInfo::getProjectOperator).collect(Collectors.toList());
   }
 
   @Override
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java
index 431542eeba..0a8c20bc93 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java
@@ -36,11 +36,13 @@ import org.apache.pinot.core.operator.blocks.ValueBlock;
 import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils.AggregationInfo;
 import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
 import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
 import org.apache.pinot.core.query.request.context.QueryContext;
+import org.apache.pinot.core.startree.executor.StarTreeGroupByExecutor;
 import org.apache.pinot.core.util.GroupByUtils;
 import org.apache.pinot.spi.trace.Tracing;
 
@@ -56,7 +58,7 @@ public class FilteredGroupByOperator extends BaseOperator<GroupByResultsBlock> {
   private final QueryContext _queryContext;
   private final AggregationFunction[] _aggregationFunctions;
   private final ExpressionContext[] _groupByExpressions;
-  private final List<Pair<AggregationFunction[], BaseProjectOperator<?>>> _projectOperators;
+  private final List<AggregationInfo> _aggregationInfos;
   private final long _numTotalDocs;
   private final DataSchema _dataSchema;
 
@@ -64,14 +66,13 @@ public class FilteredGroupByOperator extends BaseOperator<GroupByResultsBlock> {
   private long _numEntriesScannedInFilter;
   private long _numEntriesScannedPostFilter;
 
-  public FilteredGroupByOperator(QueryContext queryContext,
-      List<Pair<AggregationFunction[], BaseProjectOperator<?>>> projectOperators, long numTotalDocs) {
+  public FilteredGroupByOperator(QueryContext queryContext, List<AggregationInfo> aggregationInfos, long numTotalDocs) {
     assert queryContext.getAggregationFunctions() != null && queryContext.getFilteredAggregationFunctions() != null
         && queryContext.getGroupByExpressions() != null;
     _queryContext = queryContext;
     _aggregationFunctions = queryContext.getAggregationFunctions();
     _groupByExpressions = queryContext.getGroupByExpressions().toArray(new ExpressionContext[0]);
-    _projectOperators = projectOperators;
+    _aggregationInfos = aggregationInfos;
     _numTotalDocs = numTotalDocs;
 
     // NOTE: The indexedTable expects that the data schema will have group by columns before aggregation columns
@@ -82,7 +83,7 @@ public class FilteredGroupByOperator extends BaseOperator<GroupByResultsBlock> {
     DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[numColumns];
 
     // Extract column names and data types for group-by columns
-    BaseProjectOperator<?> projectOperator = projectOperators.get(0).getRight();
+    BaseProjectOperator<?> projectOperator = aggregationInfos.get(0).getProjectOperator();
     for (int i = 0; i < numGroupByExpressions; i++) {
       ExpressionContext groupByExpression = _groupByExpressions[i];
       columnNames[i] = groupByExpression.toString();
@@ -105,9 +106,7 @@ public class FilteredGroupByOperator extends BaseOperator<GroupByResultsBlock> {
 
   @Override
   protected GroupByResultsBlock getNextBlock() {
-    // TODO(egalpin): Support Startree query resolution when possible, even with FILTER expressions
     int numAggregations = _aggregationFunctions.length;
-
     GroupByResultHolder[] groupByResultHolders = new GroupByResultHolder[numAggregations];
     IdentityHashMap<AggregationFunction, Integer> resultHolderIndexMap =
         new IdentityHashMap<>(_aggregationFunctions.length);
@@ -116,9 +115,9 @@ public class FilteredGroupByOperator extends BaseOperator<GroupByResultsBlock> {
     }
 
     GroupKeyGenerator groupKeyGenerator = null;
-    for (Pair<AggregationFunction[], BaseProjectOperator<?>> pair : _projectOperators) {
-      AggregationFunction[] aggregationFunctions = pair.getLeft();
-      BaseProjectOperator<?> projectOperator = pair.getRight();
+    for (AggregationInfo aggregationInfo : _aggregationInfos) {
+      AggregationFunction[] aggregationFunctions = aggregationInfo.getFunctions();
+      BaseProjectOperator<?> projectOperator = aggregationInfo.getProjectOperator();
 
       // Perform aggregation group-by on all the blocks
       DefaultGroupByExecutor groupByExecutor;
@@ -130,13 +129,24 @@ public class FilteredGroupByOperator extends BaseOperator<GroupByResultsBlock> {
         // the GroupByExecutor to have sole ownership of the GroupKeyGenerator. Therefore, we allow constructing a
         // GroupByExecutor with a pre-existing GroupKeyGenerator so that the GroupKeyGenerator can be shared across
         // loop iterations i.e. across all aggs.
-        groupByExecutor =
-            new DefaultGroupByExecutor(_queryContext, aggregationFunctions, _groupByExpressions, projectOperator);
+        if (aggregationInfo.isUseStarTree()) {
+          groupByExecutor =
+              new StarTreeGroupByExecutor(_queryContext, aggregationFunctions, _groupByExpressions, projectOperator);
+        } else {
+          groupByExecutor =
+              new DefaultGroupByExecutor(_queryContext, aggregationFunctions, _groupByExpressions, projectOperator);
+        }
         groupKeyGenerator = groupByExecutor.getGroupKeyGenerator();
       } else {
-        groupByExecutor =
-            new DefaultGroupByExecutor(_queryContext, aggregationFunctions, _groupByExpressions, projectOperator,
-                groupKeyGenerator);
+        if (aggregationInfo.isUseStarTree()) {
+          groupByExecutor =
+              new StarTreeGroupByExecutor(_queryContext, aggregationFunctions, _groupByExpressions, projectOperator,
+                  groupKeyGenerator);
+        } else {
+          groupByExecutor =
+              new DefaultGroupByExecutor(_queryContext, aggregationFunctions, _groupByExpressions, projectOperator,
+                  groupKeyGenerator);
+        }
       }
 
       int numDocsScanned = 0;
@@ -191,7 +201,7 @@ public class FilteredGroupByOperator extends BaseOperator<GroupByResultsBlock> {
 
   @Override
   public List<Operator> getChildOperators() {
-    return _projectOperators.stream().map(Pair::getRight).collect(Collectors.toList());
+    return _aggregationInfos.stream().map(AggregationInfo::getProjectOperator).collect(Collectors.toList());
   }
 
   @Override
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/GroupByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/GroupByOperator.java
index 6ccc250925..eaae20e7e3 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/GroupByOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/GroupByOperator.java
@@ -32,6 +32,7 @@ import org.apache.pinot.core.operator.ExecutionStatistics;
 import org.apache.pinot.core.operator.blocks.ValueBlock;
 import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils.AggregationInfo;
 import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByExecutor;
 import org.apache.pinot.core.query.request.context.QueryContext;
@@ -51,23 +52,23 @@ public class GroupByOperator extends BaseOperator<GroupByResultsBlock> {
   private final AggregationFunction[] _aggregationFunctions;
   private final ExpressionContext[] _groupByExpressions;
   private final BaseProjectOperator<?> _projectOperator;
-  private final long _numTotalDocs;
   private final boolean _useStarTree;
+  private final long _numTotalDocs;
   private final DataSchema _dataSchema;
 
   private int _numDocsScanned = 0;
 
-  public GroupByOperator(QueryContext queryContext, ExpressionContext[] groupByExpressions,
-      BaseProjectOperator<?> projectOperator, long numTotalDocs, boolean useStarTree) {
+  public GroupByOperator(QueryContext queryContext, AggregationInfo aggregationInfo, long numTotalDocs) {
+    assert queryContext.getAggregationFunctions() != null && queryContext.getGroupByExpressions() != null;
     _queryContext = queryContext;
     _aggregationFunctions = queryContext.getAggregationFunctions();
-    _groupByExpressions = groupByExpressions;
-    _projectOperator = projectOperator;
+    _groupByExpressions = queryContext.getGroupByExpressions().toArray(new ExpressionContext[0]);
+    _projectOperator = aggregationInfo.getProjectOperator();
+    _useStarTree = aggregationInfo.isUseStarTree();
     _numTotalDocs = numTotalDocs;
-    _useStarTree = useStarTree;
 
-    // NOTE: The indexedTable expects that the the data schema will have group by columns before aggregation columns
-    int numGroupByExpressions = groupByExpressions.length;
+    // NOTE: The indexedTable expects that the data schema will have group by columns before aggregation columns
+    int numGroupByExpressions = _groupByExpressions.length;
     int numAggregationFunctions = _aggregationFunctions.length;
     int numColumns = numGroupByExpressions + numAggregationFunctions;
     String[] columnNames = new String[numColumns];
@@ -75,7 +76,7 @@ public class GroupByOperator extends BaseOperator<GroupByResultsBlock> {
 
     // Extract column names and data types for group-by columns
     for (int i = 0; i < numGroupByExpressions; i++) {
-      ExpressionContext groupByExpression = groupByExpressions[i];
+      ExpressionContext groupByExpression = _groupByExpressions[i];
       columnNames[i] = groupByExpression.toString();
       columnDataTypes[i] = DataSchema.ColumnDataType.fromDataTypeSV(
           _projectOperator.getResultColumnContext(groupByExpression).getDataType());
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
index 2cf067864e..d09e465cf0 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
@@ -20,12 +20,8 @@ package org.apache.pinot.core.plan;
 
 import java.util.EnumSet;
 import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.core.common.Operator;
-import org.apache.pinot.core.operator.BaseProjectOperator;
 import org.apache.pinot.core.operator.blocks.results.AggregationResultsBlock;
 import org.apache.pinot.core.operator.filter.BaseFilterOperator;
 import org.apache.pinot.core.operator.query.AggregationOperator;
@@ -34,15 +30,11 @@ import org.apache.pinot.core.operator.query.FilteredAggregationOperator;
 import org.apache.pinot.core.operator.query.NonScanBasedAggregationOperator;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils.AggregationInfo;
 import org.apache.pinot.core.query.request.context.QueryContext;
-import org.apache.pinot.core.startree.CompositePredicateEvaluator;
-import org.apache.pinot.core.startree.StarTreeUtils;
-import org.apache.pinot.core.startree.plan.StarTreeProjectPlanNode;
 import org.apache.pinot.segment.spi.AggregationFunctionType;
 import org.apache.pinot.segment.spi.IndexSegment;
 import org.apache.pinot.segment.spi.datasource.DataSource;
-import org.apache.pinot.segment.spi.index.startree.AggregationFunctionColumnPair;
-import org.apache.pinot.segment.spi.index.startree.StarTreeV2;
 
 import static org.apache.pinot.segment.spi.AggregationFunctionType.*;
 
@@ -81,9 +73,8 @@ public class AggregationPlanNode implements PlanNode {
    * Build the operator to be used for filtered aggregations
    */
   private FilteredAggregationOperator buildFilteredAggOperator() {
-    List<Pair<AggregationFunction[], BaseProjectOperator<?>>> projectOperators =
-        AggregationFunctionUtils.buildFilteredAggregateProjectOperators(_indexSegment, _queryContext);
-    return new FilteredAggregationOperator(_queryContext, projectOperators,
+    return new FilteredAggregationOperator(_queryContext,
+        AggregationFunctionUtils.buildFilteredAggregationInfos(_indexSegment, _queryContext),
         _indexSegment.getSegmentMetadata().getTotalDocs());
   }
 
@@ -93,11 +84,10 @@ public class AggregationPlanNode implements PlanNode {
    * aggregates code will be invoked
    */
   public Operator<AggregationResultsBlock> buildNonFilteredAggOperator() {
-    assert _queryContext.getAggregationFunctions() != null;
-
-    int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
     AggregationFunction[] aggregationFunctions = _queryContext.getAggregationFunctions();
+    assert aggregationFunctions != null;
 
+    int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
     FilterPlanNode filterPlanNode = new FilterPlanNode(_indexSegment, _queryContext);
     BaseFilterOperator filterOperator = filterPlanNode.run();
 
@@ -117,38 +107,12 @@ public class AggregationPlanNode implements PlanNode {
         }
         return new NonScanBasedAggregationOperator(_queryContext, dataSources, numTotalDocs);
       }
-
-      // Use star-tree to solve the query if possible
-      List<StarTreeV2> starTrees = _indexSegment.getStarTrees();
-      if (!filterOperator.isResultEmpty() && starTrees != null && !_queryContext.isSkipStarTree()) {
-        AggregationFunctionColumnPair[] aggregationFunctionColumnPairs =
-            StarTreeUtils.extractAggregationFunctionPairs(aggregationFunctions);
-        if (aggregationFunctionColumnPairs != null) {
-          Map<String, List<CompositePredicateEvaluator>> predicateEvaluatorsMap =
-              StarTreeUtils.extractPredicateEvaluatorsMap(_indexSegment, _queryContext.getFilter(),
-                  filterPlanNode.getPredicateEvaluators());
-          if (predicateEvaluatorsMap != null) {
-            for (StarTreeV2 starTreeV2 : starTrees) {
-              if (StarTreeUtils.isFitForStarTree(starTreeV2.getMetadata(), aggregationFunctionColumnPairs, null,
-                  predicateEvaluatorsMap.keySet())) {
-                BaseProjectOperator<?> projectOperator =
-                    new StarTreeProjectPlanNode(_queryContext, starTreeV2, aggregationFunctionColumnPairs, null,
-                        predicateEvaluatorsMap).run();
-                return new AggregationOperator(_queryContext, projectOperator, numTotalDocs, true);
-              }
-            }
-          }
-        }
-      }
     }
 
-    // TODO: Do not create ProjectOperator when filter result is empty
-    Set<ExpressionContext> expressionsToTransform =
-        AggregationFunctionUtils.collectExpressionsToTransform(aggregationFunctions, null);
-    BaseProjectOperator<?> projectOperator =
-        new ProjectPlanNode(_indexSegment, _queryContext, expressionsToTransform, DocIdSetPlanNode.MAX_DOC_PER_CALL,
-            filterOperator).run();
-    return new AggregationOperator(_queryContext, projectOperator, numTotalDocs, false);
+    AggregationInfo aggregationInfo =
+        AggregationFunctionUtils.buildAggregationInfo(_indexSegment, _queryContext, aggregationFunctions,
+            _queryContext.getFilter(), filterOperator, filterPlanNode.getPredicateEvaluators());
+    return new AggregationOperator(_queryContext, aggregationInfo, numTotalDocs);
   }
 
   /**
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/FilterPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/FilterPlanNode.java
index 6eafcbf170..d1cef2ef26 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/FilterPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/FilterPlanNode.java
@@ -61,7 +61,6 @@ import org.roaringbitmap.buffer.MutableRoaringBitmap;
 
 
 public class FilterPlanNode implements PlanNode {
-
   private final IndexSegment _indexSegment;
   private final QueryContext _queryContext;
   private final FilterContext _filter;
@@ -76,7 +75,7 @@ public class FilterPlanNode implements PlanNode {
   public FilterPlanNode(IndexSegment indexSegment, QueryContext queryContext, @Nullable FilterContext filter) {
     _indexSegment = indexSegment;
     _queryContext = queryContext;
-    _filter = filter;
+    _filter = filter != null ? filter : _queryContext.getFilter();
   }
 
   @Override
@@ -96,9 +95,8 @@ public class FilterPlanNode implements PlanNode {
     }
     int numDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
 
-    FilterContext filter = _filter != null ? _filter : _queryContext.getFilter();
-    if (filter != null) {
-      BaseFilterOperator filterOperator = constructPhysicalOperator(filter, numDocs);
+    if (_filter != null) {
+      BaseFilterOperator filterOperator = constructPhysicalOperator(_filter, numDocs);
       if (queryableDocIdSnapshot != null) {
         BaseFilterOperator validDocFilter = new BitmapBasedFilterOperator(queryableDocIdSnapshot, false, numDocs);
         return FilterOperatorUtils.getAndFilterOperator(_queryContext, Arrays.asList(filterOperator, validDocFilter),
@@ -312,6 +310,8 @@ public class FilterPlanNode implements PlanNode {
               return FilterOperatorUtils.getLeafFilterOperator(_queryContext, predicateEvaluator, dataSource, numDocs);
           }
         }
+      case CONSTANT:
+        return filter.isConstantTrue() ? new MatchAllFilterOperator(numDocs) : EmptyFilterOperator.getInstance();
       default:
         throw new IllegalStateException();
     }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java
index d5324e073e..89b1afa552 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java
@@ -18,32 +18,19 @@
  */
 package org.apache.pinot.core.plan;
 
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import org.apache.commons.lang3.tuple.Pair;
-import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.core.common.Operator;
-import org.apache.pinot.core.operator.BaseProjectOperator;
 import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
 import org.apache.pinot.core.operator.filter.BaseFilterOperator;
 import org.apache.pinot.core.operator.query.FilteredGroupByOperator;
 import org.apache.pinot.core.operator.query.GroupByOperator;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 import org.apache.pinot.core.query.request.context.QueryContext;
-import org.apache.pinot.core.startree.CompositePredicateEvaluator;
-import org.apache.pinot.core.startree.StarTreeUtils;
-import org.apache.pinot.core.startree.plan.StarTreeProjectPlanNode;
 import org.apache.pinot.segment.spi.IndexSegment;
-import org.apache.pinot.segment.spi.index.startree.AggregationFunctionColumnPair;
-import org.apache.pinot.segment.spi.index.startree.StarTreeV2;
 
 
 /**
  * The <code>GroupByPlanNode</code> class provides the execution plan for group-by query on a single segment.
  */
-@SuppressWarnings("rawtypes")
 public class GroupByPlanNode implements PlanNode {
   private final IndexSegment _indexSegment;
   private final QueryContext _queryContext;
@@ -60,52 +47,18 @@ public class GroupByPlanNode implements PlanNode {
   }
 
   private FilteredGroupByOperator buildFilteredGroupByPlan() {
-    List<Pair<AggregationFunction[], BaseProjectOperator<?>>> projectOperators =
-        AggregationFunctionUtils.buildFilteredAggregateProjectOperators(_indexSegment, _queryContext);
-    return new FilteredGroupByOperator(_queryContext, projectOperators,
+    return new FilteredGroupByOperator(_queryContext,
+        AggregationFunctionUtils.buildFilteredAggregationInfos(_indexSegment, _queryContext),
         _indexSegment.getSegmentMetadata().getTotalDocs());
   }
 
   private GroupByOperator buildNonFilteredGroupByPlan() {
-    int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
-    AggregationFunction[] aggregationFunctions = _queryContext.getAggregationFunctions();
-    List<ExpressionContext> groupByExpressionsList = _queryContext.getGroupByExpressions();
-    assert aggregationFunctions != null && groupByExpressionsList != null;
-    ExpressionContext[] groupByExpressions = groupByExpressionsList.toArray(new ExpressionContext[0]);
-
     FilterPlanNode filterPlanNode = new FilterPlanNode(_indexSegment, _queryContext);
     BaseFilterOperator filterOperator = filterPlanNode.run();
-
-    // Use star-tree to solve the query if possible
-    List<StarTreeV2> starTrees = _indexSegment.getStarTrees();
-    if (!_queryContext.isNullHandlingEnabled() && !filterOperator.isResultEmpty() && starTrees != null
-        && !_queryContext.isSkipStarTree()) {
-      AggregationFunctionColumnPair[] aggregationFunctionColumnPairs =
-          StarTreeUtils.extractAggregationFunctionPairs(aggregationFunctions);
-      if (aggregationFunctionColumnPairs != null) {
-        Map<String, List<CompositePredicateEvaluator>> predicateEvaluatorsMap =
-            StarTreeUtils.extractPredicateEvaluatorsMap(_indexSegment, _queryContext.getFilter(),
-                filterPlanNode.getPredicateEvaluators());
-        if (predicateEvaluatorsMap != null) {
-          for (StarTreeV2 starTreeV2 : starTrees) {
-            if (StarTreeUtils.isFitForStarTree(starTreeV2.getMetadata(), aggregationFunctionColumnPairs,
-                groupByExpressions, predicateEvaluatorsMap.keySet())) {
-              BaseProjectOperator<?> projectOperator =
-                  new StarTreeProjectPlanNode(_queryContext, starTreeV2, aggregationFunctionColumnPairs,
-                      groupByExpressions, predicateEvaluatorsMap).run();
-              return new GroupByOperator(_queryContext, groupByExpressions, projectOperator, numTotalDocs, true);
-            }
-          }
-        }
-      }
-    }
-
-    // TODO: Do not create ProjectOperator when filter result is empty
-    Set<ExpressionContext> expressionsToTransform =
-        AggregationFunctionUtils.collectExpressionsToTransform(aggregationFunctions, groupByExpressionsList);
-    BaseProjectOperator<?> projectOperator =
-        new ProjectPlanNode(_indexSegment, _queryContext, expressionsToTransform, DocIdSetPlanNode.MAX_DOC_PER_CALL,
-            filterOperator).run();
-    return new GroupByOperator(_queryContext, groupByExpressions, projectOperator, numTotalDocs, false);
+    AggregationFunctionUtils.AggregationInfo aggregationInfo =
+        AggregationFunctionUtils.buildAggregationInfo(_indexSegment, _queryContext,
+            _queryContext.getAggregationFunctions(), _queryContext.getFilter(), filterOperator,
+            filterPlanNode.getPredicateEvaluators());
+    return new GroupByOperator(_queryContext, aggregationInfo, _indexSegment.getSegmentMetadata().getTotalDocs());
   }
 }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
index 7a077ccbd5..cb0d3179d4 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
@@ -32,6 +32,7 @@ import org.apache.pinot.common.CustomObject;
 import org.apache.pinot.common.datatable.DataTable;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.request.context.FilterContext;
+import org.apache.pinot.common.request.context.predicate.Predicate;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.common.ObjectSerDeUtils;
@@ -39,10 +40,12 @@ import org.apache.pinot.core.operator.BaseProjectOperator;
 import org.apache.pinot.core.operator.blocks.ValueBlock;
 import org.apache.pinot.core.operator.filter.BaseFilterOperator;
 import org.apache.pinot.core.operator.filter.CombinedFilterOperator;
+import org.apache.pinot.core.operator.filter.predicate.PredicateEvaluator;
 import org.apache.pinot.core.plan.DocIdSetPlanNode;
 import org.apache.pinot.core.plan.FilterPlanNode;
 import org.apache.pinot.core.plan.ProjectPlanNode;
 import org.apache.pinot.core.query.request.context.QueryContext;
+import org.apache.pinot.core.startree.StarTreeUtils;
 import org.apache.pinot.segment.spi.AggregationFunctionType;
 import org.apache.pinot.segment.spi.IndexSegment;
 import org.apache.pinot.segment.spi.index.startree.AggregationFunctionColumnPair;
@@ -211,14 +214,69 @@ public class AggregationFunctionUtils {
     }
   }
 
+  public static class AggregationInfo {
+    private final AggregationFunction[] _functions;
+    private final BaseProjectOperator<?> _projectOperator;
+    private final boolean _useStarTree;
+
+    public AggregationInfo(AggregationFunction[] functions, BaseProjectOperator<?> projectOperator,
+        boolean useStarTree) {
+      _functions = functions;
+      _projectOperator = projectOperator;
+      _useStarTree = useStarTree;
+    }
+
+    public AggregationFunction[] getFunctions() {
+      return _functions;
+    }
+
+    public BaseProjectOperator<?> getProjectOperator() {
+      return _projectOperator;
+    }
+
+    public boolean isUseStarTree() {
+      return _useStarTree;
+    }
+  }
+
+  /**
+   * Builds {@link AggregationInfo} for aggregations.
+   */
+  public static AggregationInfo buildAggregationInfo(IndexSegment indexSegment, QueryContext queryContext,
+      AggregationFunction[] aggregationFunctions, @Nullable FilterContext filter, BaseFilterOperator filterOperator,
+      List<Pair<Predicate, PredicateEvaluator>> predicateEvaluators) {
+    BaseProjectOperator<?> projectOperator = null;
+
+    // TODO: Create a short-circuit ProjectOperator when filter result is empty
+    if (!filterOperator.isResultEmpty()) {
+      projectOperator =
+          StarTreeUtils.createStarTreeBasedProjectOperator(indexSegment, queryContext, aggregationFunctions, filter,
+              predicateEvaluators);
+    }
+
+    if (projectOperator != null) {
+      return new AggregationInfo(aggregationFunctions, projectOperator, true);
+    } else {
+      Set<ExpressionContext> expressionsToTransform =
+          AggregationFunctionUtils.collectExpressionsToTransform(aggregationFunctions,
+              queryContext.getGroupByExpressions());
+      projectOperator =
+          new ProjectPlanNode(indexSegment, queryContext, expressionsToTransform, DocIdSetPlanNode.MAX_DOC_PER_CALL,
+              filterOperator).run();
+      return new AggregationInfo(aggregationFunctions, projectOperator, false);
+    }
+  }
+
   /**
-   * Build pairs of filtered aggregation functions and corresponding project operator.
+   * Builds swim-lanes (list of {@link AggregationInfo}) for filtered aggregations.
    */
-  public static List<Pair<AggregationFunction[], BaseProjectOperator<?>>> buildFilteredAggregateProjectOperators(
-      IndexSegment indexSegment, QueryContext queryContext) {
+  public static List<AggregationInfo> buildFilteredAggregationInfos(IndexSegment indexSegment,
+      QueryContext queryContext) {
     assert queryContext.getAggregationFunctions() != null && queryContext.getFilteredAggregationFunctions() != null;
 
-    BaseFilterOperator mainFilterOperator = new FilterPlanNode(indexSegment, queryContext).run();
+    FilterPlanNode mainFilterPlan = new FilterPlanNode(indexSegment, queryContext);
+    BaseFilterOperator mainFilterOperator = mainFilterPlan.run();
+    List<Pair<Predicate, PredicateEvaluator>> mainPredicateEvaluators = mainFilterPlan.getPredicateEvaluators();
 
     // No need to process sub-filters when main filter has empty result
     if (mainFilterOperator.isResultEmpty()) {
@@ -228,68 +286,88 @@ public class AggregationFunctionUtils {
       BaseProjectOperator<?> projectOperator =
           new ProjectPlanNode(indexSegment, queryContext, expressions, DocIdSetPlanNode.MAX_DOC_PER_CALL,
               mainFilterOperator).run();
-      return Collections.singletonList(Pair.of(aggregationFunctions, projectOperator));
+      return Collections.singletonList(new AggregationInfo(aggregationFunctions, projectOperator, false));
     }
 
     // For each aggregation function, check if the aggregation function is a filtered aggregate. If so, populate the
     // corresponding filter operator.
-    Map<FilterContext, Pair<BaseFilterOperator, List<AggregationFunction>>> filterOperators = new HashMap<>();
+    Map<FilterContext, FilteredAggregationContext> filteredAggregationContexts = new HashMap<>();
     List<AggregationFunction> nonFilteredFunctions = new ArrayList<>();
+    FilterContext mainFilter = queryContext.getFilter();
     for (Pair<AggregationFunction, FilterContext> functionFilterPair : queryContext.getFilteredAggregationFunctions()) {
       AggregationFunction aggregationFunction = functionFilterPair.getLeft();
       FilterContext filter = functionFilterPair.getRight();
       if (filter != null) {
-        filterOperators.computeIfAbsent(filter, k -> {
+        filteredAggregationContexts.computeIfAbsent(filter, k -> {
+          FilterContext combinedFilter;
+          if (mainFilter == null) {
+            combinedFilter = filter;
+          } else {
+            combinedFilter = FilterContext.forAnd(List.of(mainFilter, filter));
+          }
+
+          FilterPlanNode subFilterPlan = new FilterPlanNode(indexSegment, queryContext, filter);
+          BaseFilterOperator subFilterOperator = subFilterPlan.run();
           BaseFilterOperator combinedFilterOperator;
-          BaseFilterOperator subFilterOperator = new FilterPlanNode(indexSegment, queryContext, filter).run();
-          if (mainFilterOperator.isResultMatchingAll()) {
+          if (mainFilterOperator.isResultMatchingAll() || subFilterOperator.isResultEmpty()) {
             combinedFilterOperator = subFilterOperator;
+          } else if (subFilterOperator.isResultMatchingAll()) {
+            combinedFilterOperator = mainFilterOperator;
           } else {
-            if (subFilterOperator.isResultEmpty()) {
-              combinedFilterOperator = subFilterOperator;
-            } else if (subFilterOperator.isResultMatchingAll()) {
-              combinedFilterOperator = mainFilterOperator;
-            } else {
-              combinedFilterOperator =
-                  new CombinedFilterOperator(mainFilterOperator, subFilterOperator, queryContext.getQueryOptions());
-            }
+            combinedFilterOperator =
+                new CombinedFilterOperator(mainFilterOperator, subFilterOperator, queryContext.getQueryOptions());
           }
-          return Pair.of(combinedFilterOperator, new ArrayList<>());
-        }).getRight().add(aggregationFunction);
+
+          List<Pair<Predicate, PredicateEvaluator>> subPredicateEvaluators = subFilterPlan.getPredicateEvaluators();
+          List<Pair<Predicate, PredicateEvaluator>> combinedPredicateEvaluators =
+              new ArrayList<>(mainPredicateEvaluators.size() + subPredicateEvaluators.size());
+          combinedPredicateEvaluators.addAll(mainPredicateEvaluators);
+          combinedPredicateEvaluators.addAll(subPredicateEvaluators);
+
+          return new FilteredAggregationContext(combinedFilter, combinedFilterOperator, combinedPredicateEvaluators);
+        })._aggregationFunctions.add(aggregationFunction);
       } else {
         nonFilteredFunctions.add(aggregationFunction);
       }
     }
 
-    // Create the project operators
-    List<Pair<AggregationFunction[], BaseProjectOperator<?>>> projectOperators = new ArrayList<>();
-    List<ExpressionContext> groupByExpressions = queryContext.getGroupByExpressions();
-    for (Pair<BaseFilterOperator, List<AggregationFunction>> filterOperatorFunctionsPair : filterOperators.values()) {
-      BaseFilterOperator filterOperator = filterOperatorFunctionsPair.getLeft();
+    List<AggregationInfo> aggregationInfos = new ArrayList<>();
+    for (FilteredAggregationContext filteredAggregationContext : filteredAggregationContexts.values()) {
+      BaseFilterOperator filterOperator = filteredAggregationContext._filterOperator;
       if (filterOperator == mainFilterOperator) {
         // This can happen when the sub filter matches all documents, and we can treat the function as non-filtered
-        nonFilteredFunctions.addAll(filterOperatorFunctionsPair.getRight());
+        nonFilteredFunctions.addAll(filteredAggregationContext._aggregationFunctions);
       } else {
         AggregationFunction[] aggregationFunctions =
-            filterOperatorFunctionsPair.getRight().toArray(new AggregationFunction[0]);
-        Set<ExpressionContext> expressions = collectExpressionsToTransform(aggregationFunctions, groupByExpressions);
-        BaseProjectOperator<?> projectOperator =
-            new ProjectPlanNode(indexSegment, queryContext, expressions, DocIdSetPlanNode.MAX_DOC_PER_CALL,
-                filterOperator).run();
-        projectOperators.add(Pair.of(aggregationFunctions, projectOperator));
+            filteredAggregationContext._aggregationFunctions.toArray(new AggregationFunction[0]);
+        aggregationInfos.add(
+            buildAggregationInfo(indexSegment, queryContext, aggregationFunctions, filteredAggregationContext._filter,
+                filteredAggregationContext._filterOperator, filteredAggregationContext._predicateEvaluators));
       }
     }
 
     if (!nonFilteredFunctions.isEmpty()) {
       AggregationFunction[] aggregationFunctions = nonFilteredFunctions.toArray(new AggregationFunction[0]);
-      Set<ExpressionContext> expressions = collectExpressionsToTransform(aggregationFunctions, groupByExpressions);
-      BaseProjectOperator<?> projectOperator =
-          new ProjectPlanNode(indexSegment, queryContext, expressions, DocIdSetPlanNode.MAX_DOC_PER_CALL,
-              mainFilterOperator).run();
-      projectOperators.add(Pair.of(aggregationFunctions, projectOperator));
+      aggregationInfos.add(
+          buildAggregationInfo(indexSegment, queryContext, aggregationFunctions, mainFilter, mainFilterOperator,
+              mainPredicateEvaluators));
     }
 
-    return projectOperators;
+    return aggregationInfos;
+  }
+
+  private static class FilteredAggregationContext {
+    final FilterContext _filter;
+    final BaseFilterOperator _filterOperator;
+    final List<Pair<Predicate, PredicateEvaluator>> _predicateEvaluators;
+    final List<AggregationFunction> _aggregationFunctions = new ArrayList<>();
+
+    public FilteredAggregationContext(FilterContext filter, BaseFilterOperator filterOperator,
+        List<Pair<Predicate, PredicateEvaluator>> predicateEvaluators) {
+      _filter = filter;
+      _filterOperator = filterOperator;
+      _predicateEvaluators = predicateEvaluators;
+    }
   }
 
   public static String getResultColumnName(AggregationFunction aggregationFunction, @Nullable FilterContext filter) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/startree/StarTreeUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/startree/StarTreeUtils.java
index c8de00208f..ceeca782d4 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/startree/StarTreeUtils.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/startree/StarTreeUtils.java
@@ -32,13 +32,17 @@ import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.request.context.FilterContext;
 import org.apache.pinot.common.request.context.predicate.Predicate;
+import org.apache.pinot.core.operator.BaseProjectOperator;
 import org.apache.pinot.core.operator.filter.predicate.PredicateEvaluator;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
+import org.apache.pinot.core.query.request.context.QueryContext;
+import org.apache.pinot.core.startree.plan.StarTreeProjectPlanNode;
 import org.apache.pinot.segment.spi.IndexSegment;
 import org.apache.pinot.segment.spi.datasource.DataSource;
 import org.apache.pinot.segment.spi.index.reader.Dictionary;
 import org.apache.pinot.segment.spi.index.startree.AggregationFunctionColumnPair;
+import org.apache.pinot.segment.spi.index.startree.StarTreeV2;
 import org.apache.pinot.segment.spi.index.startree.StarTreeV2Metadata;
 
 
@@ -115,8 +119,8 @@ public class StarTreeUtils {
           return null;
         case PREDICATE:
           Predicate predicate = filterNode.getPredicate();
-          PredicateEvaluator predicateEvaluator = getPredicateEvaluator(indexSegment, predicate,
-              predicateEvaluatorMapping);
+          PredicateEvaluator predicateEvaluator =
+              getPredicateEvaluator(indexSegment, predicate, predicateEvaluatorMapping);
           // Do not use star-tree when the predicate cannot be solved with star-tree or is always false
           if (predicateEvaluator == null || predicateEvaluator.isAlwaysFalse()) {
             return null;
@@ -277,10 +281,44 @@ public class StarTreeUtils {
         break;
     }
     for (Pair<Predicate, PredicateEvaluator> pair : predicatesEvaluatorMapping) {
-      if (pair.getKey().equals(predicate)) {
+      if (pair.getKey() == predicate) {
         return pair.getValue();
       }
     }
     return null;
   }
+
+  /**
+   * Returns a {@link BaseProjectOperator} when the filter can be solved with star-tree, or {@code null} otherwise.
+   */
+  @Nullable
+  public static BaseProjectOperator<?> createStarTreeBasedProjectOperator(IndexSegment indexSegment,
+      QueryContext queryContext, AggregationFunction[] aggregationFunctions, @Nullable FilterContext filter,
+      List<Pair<Predicate, PredicateEvaluator>> predicateEvaluators) {
+    List<StarTreeV2> starTrees = indexSegment.getStarTrees();
+    if (starTrees == null || queryContext.isSkipStarTree() || queryContext.isNullHandlingEnabled()) {
+      return null;
+    }
+    AggregationFunctionColumnPair[] aggregationFunctionColumnPairs =
+        extractAggregationFunctionPairs(aggregationFunctions);
+    if (aggregationFunctionColumnPairs == null) {
+      return null;
+    }
+    Map<String, List<CompositePredicateEvaluator>> predicateEvaluatorsMap =
+        extractPredicateEvaluatorsMap(indexSegment, filter, predicateEvaluators);
+    if (predicateEvaluatorsMap == null) {
+      return null;
+    }
+    ExpressionContext[] groupByExpressions =
+        queryContext.getGroupByExpressions() != null ? queryContext.getGroupByExpressions()
+            .toArray(new ExpressionContext[0]) : null;
+    for (StarTreeV2 starTreeV2 : starTrees) {
+      if (isFitForStarTree(starTreeV2.getMetadata(), aggregationFunctionColumnPairs, groupByExpressions,
+          predicateEvaluatorsMap.keySet())) {
+        return new StarTreeProjectPlanNode(queryContext, starTreeV2, aggregationFunctionColumnPairs, groupByExpressions,
+            predicateEvaluatorsMap).run();
+      }
+    }
+    return null;
+  }
 }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeGroupByExecutor.java b/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeGroupByExecutor.java
index bad3d4ad08..6441fcc98d 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeGroupByExecutor.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeGroupByExecutor.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.core.startree.executor;
 
 import java.util.Map;
+import javax.annotation.Nullable;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.core.operator.BaseProjectOperator;
@@ -27,6 +28,7 @@ import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
 import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
 import org.apache.pinot.core.query.request.context.QueryContext;
 import org.apache.pinot.segment.spi.index.startree.AggregationFunctionColumnPair;
 
@@ -45,9 +47,19 @@ public class StarTreeGroupByExecutor extends DefaultGroupByExecutor {
 
   public StarTreeGroupByExecutor(QueryContext queryContext, ExpressionContext[] groupByExpressions,
       BaseProjectOperator<?> projectOperator) {
-    super(queryContext, groupByExpressions, projectOperator);
+    this(queryContext, queryContext.getAggregationFunctions(), groupByExpressions, projectOperator, null);
+  }
+
+  public StarTreeGroupByExecutor(QueryContext queryContext, AggregationFunction[] aggregationFunctions,
+      ExpressionContext[] groupByExpressions, BaseProjectOperator<?> projectOperator) {
+    this(queryContext, aggregationFunctions, groupByExpressions, projectOperator, null);
+  }
+
+  public StarTreeGroupByExecutor(QueryContext queryContext, AggregationFunction[] aggregationFunctions,
+      ExpressionContext[] groupByExpressions, BaseProjectOperator<?> projectOperator,
+      @Nullable GroupKeyGenerator groupKeyGenerator) {
+    super(queryContext, aggregationFunctions, groupByExpressions, projectOperator, groupKeyGenerator);
 
-    AggregationFunction[] aggregationFunctions = queryContext.getAggregationFunctions();
     assert aggregationFunctions != null;
     int numAggregationFunctions = aggregationFunctions.length;
     _aggregationFunctionColumnPairs = new AggregationFunctionColumnPair[numAggregationFunctions];
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
index 09b7411d1c..6f9b125a4c 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
@@ -128,10 +128,20 @@ public class OfflineClusterIntegrationTest extends BaseClusterIntegrationTestSet
       new StarTreeIndexConfig(Collections.singletonList("Carrier"), null,
           Collections.singletonList(AggregationFunctionColumnPair.COUNT_STAR.toColumnName()), null, 100);
   private static final String TEST_STAR_TREE_QUERY_1 = "SELECT COUNT(*) FROM mytable WHERE Carrier = 'UA'";
+  private static final String TEST_STAR_TREE_QUERY_1_FILTER_INVERT =
+      "SELECT COUNT(*) FILTER (WHERE Carrier = 'UA') FROM mytable";
   private static final StarTreeIndexConfig STAR_TREE_INDEX_CONFIG_2 =
       new StarTreeIndexConfig(Collections.singletonList("DestState"), null,
           Collections.singletonList(AggregationFunctionColumnPair.COUNT_STAR.toColumnName()), null, 100);
   private static final String TEST_STAR_TREE_QUERY_2 = "SELECT COUNT(*) FROM mytable WHERE DestState = 'CA'";
+  private static final String TEST_STAR_TREE_QUERY_FILTERED_AGG =
+      "SELECT COUNT(*), COUNT(*) FILTER (WHERE Carrier = 'UA') FROM mytable WHERE DestState = 'CA'";
+  // The filtered AVG(ArrDelay) cannot be solved with star-tree, but the unfiltered COUNT(*) still should be.
+  private static final String TEST_STAR_TREE_QUERY_FILTERED_AGG_MIXED =
+      "SELECT COUNT(*), AVG(ArrDelay) FILTER (WHERE Carrier = 'UA') FROM mytable WHERE DestState = 'CA'";
+  private static final StarTreeIndexConfig STAR_TREE_INDEX_CONFIG_3 =
+      new StarTreeIndexConfig(List.of("Carrier", "DestState"), null,
+          Collections.singletonList(AggregationFunctionColumnPair.COUNT_STAR.toColumnName()), null, 100);
 
   // For default columns test
   private static final String TEST_EXTRA_COLUMNS_QUERY = "SELECT COUNT(*) FROM mytable WHERE NewAddedIntMetric = 1";
@@ -1326,6 +1336,9 @@ public class OfflineClusterIntegrationTest extends BaseClusterIntegrationTestSet
     assertEquals(firstQueryResponse.get("totalDocs").asLong(), numTotalDocs);
     // Initially 'numDocsScanned' should be the same as 'COUNT(*)' result
     assertEquals(firstQueryResponse.get("numDocsScanned").asInt(), firstQueryResult);
+    // Verify that inverting the filter to be a filtered agg shows the identical results
+    firstQueryResponse = postQuery(TEST_STAR_TREE_QUERY_1_FILTER_INVERT);
+    assertEquals(firstQueryResponse.get("resultTable").get("rows").get(0).get(0).asInt(), firstQueryResult);
 
     // Update table config and trigger reload
     TableConfig tableConfig = getOfflineTableConfig();
@@ -1336,6 +1349,11 @@ public class OfflineClusterIntegrationTest extends BaseClusterIntegrationTestSet
     reloadAllSegments(TEST_STAR_TREE_QUERY_1, false, numTotalDocs);
     // With star-tree, 'numDocsScanned' should be the same as number of segments (1 per segment)
     assertEquals(postQuery(TEST_STAR_TREE_QUERY_1).get("numDocsScanned").asLong(), NUM_SEGMENTS);
+    // Verify that inverting the filter to be a filtered agg shows the identical results
+    firstQueryResponse = postQuery(TEST_STAR_TREE_QUERY_1_FILTER_INVERT);
+    assertEquals(firstQueryResponse.get("resultTable").get("rows").get(0).get(0).asInt(), firstQueryResult);
+    assertEquals(firstQueryResponse.get("totalDocs").asLong(), numTotalDocs);
+    assertEquals(firstQueryResponse.get("numDocsScanned").asInt(), NUM_SEGMENTS);
 
     // Reload again should have no effect
     reloadAllSegments(TEST_STAR_TREE_QUERY_1, false, numTotalDocs);
@@ -1343,6 +1361,11 @@ public class OfflineClusterIntegrationTest extends BaseClusterIntegrationTestSet
     assertEquals(firstQueryResponse.get("resultTable").get("rows").get(0).get(0).asInt(), firstQueryResult);
     assertEquals(firstQueryResponse.get("totalDocs").asLong(), numTotalDocs);
     assertEquals(firstQueryResponse.get("numDocsScanned").asInt(), NUM_SEGMENTS);
+    // Verify that inverting the filter to be a filtered agg shows the identical results
+    firstQueryResponse = postQuery(TEST_STAR_TREE_QUERY_1_FILTER_INVERT);
+    assertEquals(firstQueryResponse.get("resultTable").get("rows").get(0).get(0).asInt(), firstQueryResult);
+    assertEquals(firstQueryResponse.get("totalDocs").asLong(), numTotalDocs);
+    assertEquals(firstQueryResponse.get("numDocsScanned").asInt(), NUM_SEGMENTS);
 
     // Enforce a sleep here since segment reload is async and there is another back-to-back reload below.
     // Otherwise, there is no way to tell whether the 1st reload on server side is finished,
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/StarTreeClusterIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/StarTreeClusterIntegrationTest.java
index 5447dc8b2f..aba44536e7 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/StarTreeClusterIntegrationTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/StarTreeClusterIntegrationTest.java
@@ -39,6 +39,8 @@ import org.testng.annotations.BeforeClass;
 import org.testng.annotations.Test;
 
 import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertFalse;
+import static org.testng.Assert.assertTrue;
 
 
 /**
@@ -160,8 +162,8 @@ public class StarTreeClusterIntegrationTest extends BaseClusterIntegrationTest {
       throws Exception {
     setUseMultiStageQueryEngine(useMultiStageQueryEngine);
     for (int i = 0; i < NUM_QUERIES_TO_GENERATE; i += 2) {
-      testStarQuery(_starTree1QueryGenerator.nextQuery());
-      testStarQuery(_starTree2QueryGenerator.nextQuery());
+      testStarQuery(_starTree1QueryGenerator.nextQuery(), false);
+      testStarQuery(_starTree2QueryGenerator.nextQuery(), false);
     }
   }
 
@@ -174,14 +176,54 @@ public class StarTreeClusterIntegrationTest extends BaseClusterIntegrationTest {
     String starQuery = "SELECT DepTimeBlk, COUNT(*) FROM mytable "
         + "WHERE CRSDepTime BETWEEN 1137 AND 1849 AND DivArrDelay > 218 AND CRSDepTime NOT IN (35, 1633, 1457, 140) "
         + "AND LongestAddGTime NOT IN (17, 105, 20, 22) GROUP BY DepTimeBlk ORDER BY DepTimeBlk";
-    testStarQuery(starQuery);
+    testStarQuery(starQuery, !useMultiStageQueryEngine);
   }
 
-  private void testStarQuery(String starQuery)
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testHardCodedFilteredAggQueries(boolean useMultiStageQueryEngine)
       throws Exception {
-    String referenceQuery = "SET useStarTree = false; " + starQuery;
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+    String starQuery = "SELECT DepTimeBlk, COUNT(*), COUNT(*) FILTER (WHERE CRSDepTime = 35) FROM mytable "
+        + "WHERE CRSDepTime != 35 "
+        + "GROUP BY DepTimeBlk ORDER BY DepTimeBlk";
+    // Don't verify that the query plan uses StarTree index, as this query results in FILTER_EMPTY in the query plan.
+    // This is still a valuable test, as it caught a bug where only the subFilterContext was being preserved through
+    // AggregationFunctionUtils#buildFilteredAggregationInfos
+    testStarQuery(starQuery, false);
+
+    // Ensure the filtered agg and unfiltered agg can co-exist in one query
+    starQuery = "SELECT DepTimeBlk, COUNT(*), COUNT(*) FILTER (WHERE DivArrDelay > 20) FROM mytable "
+        + "WHERE CRSDepTime != 35 "
+        + "GROUP BY DepTimeBlk ORDER BY DepTimeBlk";
+    testStarQuery(starQuery, !useMultiStageQueryEngine);
+
+    starQuery = "SELECT DepTimeBlk, COUNT(*) FILTER (WHERE CRSDepTime != 35) FROM mytable "
+        + "GROUP BY DepTimeBlk ORDER BY DepTimeBlk";
+    testStarQuery(starQuery, !useMultiStageQueryEngine);
+  }
+
+  private void testStarQuery(String starQuery, boolean verifyPlan)
+      throws Exception {
+    String filterStartreeIndex = "FILTER_STARTREE_INDEX";
+    String explain = "EXPLAIN PLAN FOR ";
+    String disableStarTree = "SET useStarTree = false; ";
+
+    if (verifyPlan) {
+      JsonNode starPlan = postQuery(explain + starQuery);
+      JsonNode referencePlan = postQuery(disableStarTree + explain + starQuery);
+      assertTrue(starPlan.toString().contains(filterStartreeIndex)
+              || starPlan.toString().contains("FILTER_EMPTY")
+              || starPlan.toString().contains("ALL_SEGMENTS_PRUNED_ON_SERVER"),
+          "StarTree query did not indicate use of StarTree index in query plan. Plan: " + starPlan);
+      assertFalse(referencePlan.toString().contains(filterStartreeIndex),
+          "Reference query indicated use of StarTree index in query plan. Plan: " + referencePlan);
+    }
+
     JsonNode starResponse = postQuery(starQuery);
+    String referenceQuery = disableStarTree + starQuery;
     JsonNode referenceResponse = postQuery(referenceQuery);
+    assertEquals(starResponse.get("exceptions").size(), 0);
+    assertEquals(referenceResponse.get("exceptions").size(), 0);
     assertEquals(starResponse.get("resultTable"), referenceResponse.get("resultTable"), String.format(
         "Query comparison failed for: \n"
             + "Star Query: %s\nStar Response: %s\nReference Query: %s\nReference Response: %s\nRandom Seed: %d",
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/startree/StarTreeQueryGenerator.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/startree/StarTreeQueryGenerator.java
index 827aaa7245..ee8618f0f3 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/startree/StarTreeQueryGenerator.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/startree/StarTreeQueryGenerator.java
@@ -45,6 +45,7 @@ public class StarTreeQueryGenerator {
   private static final String IN = " IN ";
   private static final String NOT_IN = " NOT IN ";
   private static final String AND = " AND ";
+  private static final String FILTER = " FILTER (%s)";
 
   private static final int MAX_NUM_AGGREGATIONS = 5;
   private static final int MAX_NUM_PREDICATES = 10;
@@ -82,6 +83,16 @@ public class StarTreeQueryGenerator {
         metricColumn);
   }
 
+  private String generateFilteredAggregation(String metricColumn) {
+    StringBuilder filteredAgg = new StringBuilder(generateAggregation(metricColumn));
+    String predicates = generatePredicates();
+    if (predicates == null) {
+      return filteredAgg.toString();
+    }
+    filteredAgg.append(String.format(FILTER, predicates));
+    return filteredAgg.toString();
+  }
+
   /**
    * Generate the aggregation section of the query, returns at least one aggregation.
    *
@@ -92,7 +103,11 @@ public class StarTreeQueryGenerator {
     int numMetrics = _metricColumns.size();
     String[] aggregations = new String[numAggregations];
     for (int i = 0; i < numAggregations; i++) {
-      aggregations[i] = generateAggregation(_metricColumns.get(_random.nextInt(numMetrics)));
+      if (i % 3 == 0) {
+        aggregations[i] = generateFilteredAggregation(_metricColumns.get(_random.nextInt(numMetrics)));
+      } else {
+        aggregations[i] = generateAggregation(_metricColumns.get(_random.nextInt(numMetrics)));
+      }
     }
     return StringUtils.join(aggregations, ", ");
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org