You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by xi...@apache.org on 2023/05/11 00:08:12 UTC

[pinot] branch master updated: [multistage] Refactor query planner and dispatcher (#10748)

This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new fe98bb0783 [multistage] Refactor query planner and dispatcher (#10748)
fe98bb0783 is described below

commit fe98bb0783213504c77be397b1750de3962000ef
Author: Xiang Fu <xi...@gmail.com>
AuthorDate: Wed May 10 17:08:04 2023 -0700

    [multistage] Refactor query planner and dispatcher (#10748)
---
 .../MultiStageBrokerRequestHandler.java            |   8 +-
 .../org/apache/pinot/query/QueryEnvironment.java   |  69 +++++-----
 .../query/planner/DispatchablePlanFragment.java    | 119 +++++++++++++++++
 .../pinot/query/planner/DispatchableSubPlan.java   |  74 +++++++++++
 .../query/planner/ExplainPlanPlanVisitor.java      |  48 +++----
 .../apache/pinot/query/planner/PlanFragment.java   |  61 +++++++++
 .../pinot/query/planner/PlanFragmentMetadata.java  |  28 ++++
 .../org/apache/pinot/query/planner/QueryPlan.java  | 121 ++----------------
 .../pinot/query/planner/QueryPlanMetadata.java     |  63 +++++++++
 .../org/apache/pinot/query/planner/SubPlan.java    |  60 +++++++++
 .../pinot/query/planner/SubPlanMetadata.java       |  53 ++++++++
 .../query/planner/logical/LiteralValueNode.java    |  56 ++++++++
 .../planner/logical/PinotLogicalQueryPlanner.java  | 142 +++++++++++++++++++++
 .../{StageFragmenter.java => PlanFragmenter.java}  |  70 +++++++---
 ...eConverter.java => RelToPlanNodeConverter.java} |  28 +++-
 .../pinot/query/planner/logical/RexExpression.java |   6 +-
 .../query/planner/logical/RexExpressionUtils.java  |   9 +-
 .../pinot/query/planner/logical/StagePlanner.java  | 112 ----------------
 ...StageFragmenter.java => SubPlanFragmenter.java} |  72 ++++++-----
 .../planner/physical/DispatchablePlanContext.java  |  59 +++++++++
 .../planner/physical/DispatchablePlanVisitor.java  |  49 -------
 .../planner/physical/PinotDispatchPlanner.java     | 105 +++++++++++++++
 .../colocated/GreedyShuffleRewriteVisitor.java     |   6 +-
 .../pinot/query/planner/plannode/ExchangeNode.java |  14 +-
 .../query/planner/plannode/PlanNodeVisitor.java    |   3 +-
 ...lanFragmentMetadata.java => StageMetadata.java} |  14 +-
 .../apache/pinot/query/QueryCompilationTest.java   | 130 +++++++++----------
 ...erTest.java => RelToPlanNodeConverterTest.java} |  60 ++++-----
 .../query/planner/plannode/SerDeUtilsTest.java     |  21 +--
 .../query/queries/ResourceBasedQueryPlansTest.java |  44 ++++---
 .../apache/pinot/query/runtime/QueryRunner.java    |  10 +-
 .../runtime/operator/utils/OperatorUtils.java      |  10 +-
 .../query/runtime/plan/DistributedStagePlan.java   |  18 +--
 .../runtime/plan/OpChainExecutionContext.java      |  12 +-
 .../query/runtime/plan/PlanRequestContext.java     |  12 +-
 .../runtime/plan/serde/QueryPlanSerDeUtils.java    |  12 +-
 .../plan/server/ServerPlanRequestContext.java      |   6 +-
 .../query/service/dispatch/QueryDispatcher.java    |  93 ++++++++------
 .../pinot/query/runtime/QueryRunnerTest.java       |  39 +++---
 .../pinot/query/runtime/QueryRunnerTestBase.java   |  21 +--
 .../operator/MailboxReceiveOperatorTest.java       |  34 ++---
 .../runtime/operator/MailboxSendOperatorTest.java  |   6 +-
 .../pinot/query/runtime/operator/OpChainTest.java  |  14 +-
 .../operator/SortedMailboxReceiveOperatorTest.java |  38 +++---
 .../pinot/query/service/QueryServerTest.java       |  69 +++++-----
 .../service/dispatch/QueryDispatcherTest.java      |  34 ++---
 46 files changed, 1403 insertions(+), 729 deletions(-)

diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
index ffc72c0732..abfa1d4d15 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
@@ -51,7 +51,7 @@ import org.apache.pinot.core.transport.ServerInstance;
 import org.apache.pinot.query.QueryEnvironment;
 import org.apache.pinot.query.catalog.PinotCatalog;
 import org.apache.pinot.query.mailbox.MailboxService;
-import org.apache.pinot.query.planner.QueryPlan;
+import org.apache.pinot.query.planner.DispatchableSubPlan;
 import org.apache.pinot.query.routing.WorkerManager;
 import org.apache.pinot.query.service.QueryConfig;
 import org.apache.pinot.query.service.dispatch.QueryDispatcher;
@@ -175,7 +175,7 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler {
       return new BrokerResponseNative(QueryException.getException(QueryException.SQL_PARSING_ERROR, e));
     }
 
-    QueryPlan queryPlan = queryPlanResult.getQueryPlan();
+    DispatchableSubPlan dispatchableSubPlan = queryPlanResult.getQueryPlan();
     Set<String> tableNames = queryPlanResult.getTableNames();
 
     // Compilation Time. This includes the time taken for parsing, compiling, create stage plans and assigning workers.
@@ -201,13 +201,13 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler {
 
     ResultTable queryResults;
     Map<Integer, ExecutionStatsAggregator> stageIdStatsMap = new HashMap<>();
-    for (Integer stageId : queryPlan.getDispatchablePlanMetadataMap().keySet()) {
+    for (int stageId = 0; stageId < dispatchableSubPlan.getQueryStageList().size(); stageId++) {
       stageIdStatsMap.put(stageId, new ExecutionStatsAggregator(traceEnabled));
     }
 
     long executionStartTimeNs = System.nanoTime();
     try {
-      queryResults = _queryDispatcher.submitAndReduce(requestId, queryPlan, _mailboxService, queryTimeoutMs,
+      queryResults = _queryDispatcher.submitAndReduce(requestId, dispatchableSubPlan, _mailboxService, queryTimeoutMs,
           sqlNodeAndOptions.getOptions(), stageIdStatsMap, traceEnabled);
     } catch (Exception e) {
       LOGGER.info("query execution failed", e);
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
index 7c89ea74dd..c5f8ff8474 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
@@ -20,8 +20,6 @@ package org.apache.pinot.query;
 
 import com.google.common.annotations.VisibleForTesting;
 import java.util.Arrays;
-import java.util.HashSet;
-import java.util.List;
 import java.util.Properties;
 import java.util.Set;
 import javax.annotation.Nullable;
@@ -60,9 +58,13 @@ import org.apache.calcite.tools.Frameworks;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.pinot.common.config.provider.TableCache;
 import org.apache.pinot.query.context.PlannerContext;
+import org.apache.pinot.query.planner.DispatchableSubPlan;
 import org.apache.pinot.query.planner.PlannerUtils;
 import org.apache.pinot.query.planner.QueryPlan;
-import org.apache.pinot.query.planner.logical.StagePlanner;
+import org.apache.pinot.query.planner.SubPlan;
+import org.apache.pinot.query.planner.logical.PinotLogicalQueryPlanner;
+import org.apache.pinot.query.planner.logical.RelToPlanNodeConverter;
+import org.apache.pinot.query.planner.physical.PinotDispatchPlanner;
 import org.apache.pinot.query.routing.WorkerManager;
 import org.apache.pinot.query.type.TypeFactory;
 import org.apache.pinot.sql.parsers.CalciteSqlParser;
@@ -72,7 +74,7 @@ import org.apache.pinot.sql.parsers.SqlNodeAndOptions;
 /**
  * The {@code QueryEnvironment} contains the main entrypoint for query planning.
  *
- * <p>It provide the higher level entry interface to convert a SQL string into a {@link QueryPlan}.
+ * <p>It provide the higher level entry interface to convert a SQL string into a {@link DispatchableSubPlan}.
  */
 public class QueryEnvironment {
   // Calcite configurations
@@ -155,9 +157,12 @@ public class QueryEnvironment {
     try (PlannerContext plannerContext = new PlannerContext(_config, _catalogReader, _typeFactory, _hepProgram)) {
       plannerContext.setOptions(sqlNodeAndOptions.getOptions());
       RelRoot relRoot = compileQuery(sqlNodeAndOptions.getSqlNode(), plannerContext);
-      Set<String> tableNames = getTableNamesFromRelRoot(relRoot.rel);
-      return new QueryPlannerResult(toDispatchablePlan(relRoot, plannerContext, requestId, tableNames), null,
-          tableNames);
+      SubPlan subPlanRoot = toSubPlan(relRoot);
+      // TODO: current code only assume one SubPlan per query, but we should support multiple SubPlans per query.
+      // Each SubPlan should be able to run independently from Broker then set the results into the dependent
+      // SubPlan for further processing.
+      DispatchableSubPlan dispatchableSubPlan = toDispatchableSubPlan(subPlanRoot, plannerContext, requestId);
+      return new QueryPlannerResult(dispatchableSubPlan, null, dispatchableSubPlan.getTableNames());
     } catch (CalciteContextException e) {
       throw new RuntimeException("Error composing query plan for '" + sqlQuery
           + "': " + e.getMessage() + "'", e);
@@ -170,7 +175,8 @@ public class QueryEnvironment {
    * Explain a SQL query.
    *
    * Similar to {@link QueryEnvironment#planQuery(String, SqlNodeAndOptions, long)}, this API runs the query
-   * compilation. But it doesn't run the distributed {@link QueryPlan} generation, instead it only returns the
+   * compilation. But it doesn't run the distributed {@link DispatchableSubPlan} generation, instead it only
+   * returns the
    * explained logical plan.
    *
    * @param sqlQuery SQL query string.
@@ -185,7 +191,7 @@ public class QueryEnvironment {
       SqlExplainFormat format = explain.getFormat() == null ? SqlExplainFormat.DOT : explain.getFormat();
       SqlExplainLevel level =
           explain.getDetailLevel() == null ? SqlExplainLevel.DIGEST_ATTRIBUTES : explain.getDetailLevel();
-      Set<String> tableNames = getTableNamesFromRelRoot(relRoot.rel);
+      Set<String> tableNames = RelToPlanNodeConverter.getTableNamesFromRelRoot(relRoot.rel);
       return new QueryPlannerResult(null, PlannerUtils.explainPlan(relRoot.rel, format, level), tableNames);
     } catch (Exception e) {
       throw new RuntimeException("Error explain query plan for: " + sqlQuery, e);
@@ -193,7 +199,7 @@ public class QueryEnvironment {
   }
 
   @VisibleForTesting
-  public QueryPlan planQuery(String sqlQuery) {
+  public DispatchableSubPlan planQuery(String sqlQuery) {
     return planQuery(sqlQuery, CalciteSqlParser.compileToSqlNodeAndOptions(sqlQuery), 0).getQueryPlan();
   }
 
@@ -206,12 +212,13 @@ public class QueryEnvironment {
    * Results of planning a query
    */
   public static class QueryPlannerResult {
-    private QueryPlan _queryPlan;
+    private DispatchableSubPlan _dispatchableSubPlan;
     private String _explainPlan;
     Set<String> _tableNames;
 
-    QueryPlannerResult(@Nullable QueryPlan queryPlan, @Nullable String explainPlan, Set<String> tableNames) {
-      _queryPlan = queryPlan;
+    QueryPlannerResult(@Nullable DispatchableSubPlan dispatchableSubPlan, @Nullable String explainPlan,
+        Set<String> tableNames) {
+      _dispatchableSubPlan = dispatchableSubPlan;
       _explainPlan = explainPlan;
       _tableNames = tableNames;
     }
@@ -220,8 +227,8 @@ public class QueryEnvironment {
       return _explainPlan;
     }
 
-    public QueryPlan getQueryPlan() {
-      return _queryPlan;
+    public DispatchableSubPlan getQueryPlan() {
+      return _dispatchableSubPlan;
     }
 
     // Returns all the table names in the query.
@@ -297,11 +304,20 @@ public class QueryEnvironment {
     }
   }
 
-  private QueryPlan toDispatchablePlan(RelRoot relRoot, PlannerContext plannerContext, long requestId,
-      Set<String> tableNames) {
-    // 5. construct a dispatchable query plan.
-    StagePlanner queryStagePlanner = new StagePlanner(plannerContext, _workerManager, requestId, _tableCache);
-    return queryStagePlanner.makePlan(relRoot, tableNames);
+  private SubPlan toSubPlan(RelRoot relRoot) {
+    // 5. construct a logical query plan.
+    PinotLogicalQueryPlanner pinotLogicalQueryPlanner = new PinotLogicalQueryPlanner();
+    QueryPlan queryPlan = pinotLogicalQueryPlanner.planQuery(relRoot);
+    SubPlan subPlan = pinotLogicalQueryPlanner.makePlan(queryPlan);
+    return subPlan;
+  }
+
+  private DispatchableSubPlan toDispatchableSubPlan(SubPlan subPlan, PlannerContext plannerContext, long requestId) {
+    // 6. construct a dispatchable query plan.
+    PinotDispatchPlanner pinotDispatchPlanner =
+        new PinotDispatchPlanner(plannerContext, _workerManager, requestId, _tableCache);
+    DispatchableSubPlan dispatchableSubPlan = pinotDispatchPlanner.createDispatchableSubPlan(subPlan);
+    return dispatchableSubPlan;
   }
 
   // --------------------------------------------------------------------------
@@ -311,17 +327,4 @@ public class QueryEnvironment {
   private HintStrategyTable getHintStrategyTable() {
     return PinotHintStrategyTable.PINOT_HINT_STRATEGY_TABLE;
   }
-
-
-  private Set<String> getTableNamesFromRelRoot(RelNode relRoot) {
-    Set<String> tableNames = new HashSet<>();
-    List<String> qualifiedTableNames = RelOptUtil.findAllTableQualifiedNames(relRoot);
-    for (String qualifiedTableName : qualifiedTableNames) {
-      // Calcite encloses table and schema names in square brackets to properly quote and delimit them in SQL
-      // statements, particularly to handle cases when they contain special characters or reserved keywords.
-      String tableName = qualifiedTableName.replaceAll("^\\[(.*)\\]$", "$1");
-      tableNames.add(tableName);
-    }
-    return tableNames;
-  }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/DispatchablePlanFragment.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/DispatchablePlanFragment.java
new file mode 100644
index 0000000000..c06fa383a6
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/DispatchablePlanFragment.java
@@ -0,0 +1,119 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.core.routing.TimeBoundaryInfo;
+import org.apache.pinot.query.routing.QueryServerInstance;
+import org.apache.pinot.query.routing.StageMetadata;
+import org.apache.pinot.query.routing.WorkerMetadata;
+
+
+public class DispatchablePlanFragment {
+
+  public static final String TABLE_NAME_KEY = "tableName";
+  public static final String TIME_BOUNDARY_COLUMN_KEY = "timeBoundaryInfo.timeColumn";
+  public static final String TIME_BOUNDARY_VALUE_KEY = "timeBoundaryInfo.timeValue";
+  private final PlanFragment _planFragment;
+  private final List<WorkerMetadata> _workerMetadataList;
+
+  // This is used at broker stage - we don't need to ship it to the server.
+  private final Map<QueryServerInstance, List<Integer>> _serverInstanceToWorkerIdMap;
+
+  // used for table scan stage - we use ServerInstance instead of VirtualServer
+  // here because all virtual servers that share a server instance will have the
+  // same segments on them
+  private final Map<Integer, Map<String, List<String>>> _workerIdToSegmentsMap;
+
+  // used for passing custom properties to build StageMetadata on the server.
+  private final Map<String, String> _customProperties;
+
+  public DispatchablePlanFragment(PlanFragment planFragment) {
+    this(planFragment, new ArrayList<>(), new HashMap<>(), new HashMap<>());
+  }
+
+  public DispatchablePlanFragment(PlanFragment planFragment, List<WorkerMetadata> workerMetadataList,
+      Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap, Map<String, String> customPropertyMap) {
+    _planFragment = planFragment;
+    _workerMetadataList = workerMetadataList;
+    _serverInstanceToWorkerIdMap = serverInstanceToWorkerIdMap;
+    _workerIdToSegmentsMap = new HashMap<>();
+    _customProperties = customPropertyMap;
+  }
+
+  public PlanFragment getPlanFragment() {
+    return _planFragment;
+  }
+
+  public List<WorkerMetadata> getWorkerMetadataList() {
+    return _workerMetadataList;
+  }
+
+  public Map<QueryServerInstance, List<Integer>> getServerInstanceToWorkerIdMap() {
+    return _serverInstanceToWorkerIdMap;
+  }
+
+  public Map<String, String> getCustomProperties() {
+    return _customProperties;
+  }
+
+  public String getTableName() {
+    return _customProperties.get(TABLE_NAME_KEY);
+  }
+
+  public String setTableName(String tableName) {
+    return _customProperties.put(TABLE_NAME_KEY, tableName);
+  }
+
+  public TimeBoundaryInfo getTimeBoundary() {
+    return new TimeBoundaryInfo(_customProperties.get(TIME_BOUNDARY_COLUMN_KEY),
+        _customProperties.get(TIME_BOUNDARY_VALUE_KEY));
+  }
+
+  public void setTimeBoundaryInfo(TimeBoundaryInfo timeBoundaryInfo) {
+    _customProperties.put(TIME_BOUNDARY_COLUMN_KEY, timeBoundaryInfo.getTimeColumn());
+    _customProperties.put(TIME_BOUNDARY_VALUE_KEY, timeBoundaryInfo.getTimeValue());
+  }
+
+  public Map<Integer, Map<String, List<String>>> getWorkerIdToSegmentsMap() {
+    return _workerIdToSegmentsMap;
+  }
+
+  public void setWorkerIdToSegmentsMap(Map<Integer, Map<String, List<String>>> workerIdToSegmentsMap) {
+    _workerIdToSegmentsMap.clear();
+    _workerIdToSegmentsMap.putAll(workerIdToSegmentsMap);
+  }
+
+  public void setWorkerMetadataList(List<WorkerMetadata> workerMetadataList) {
+    _workerMetadataList.clear();
+    _workerMetadataList.addAll(workerMetadataList);
+  }
+
+  public StageMetadata toStageMetadata() {
+    return new StageMetadata(_workerMetadataList, _customProperties);
+  }
+
+  public void setServerInstanceToWorkerIdMap(Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap) {
+    _serverInstanceToWorkerIdMap.clear();
+    _serverInstanceToWorkerIdMap.putAll(serverInstanceToWorkerIdMap);
+  }
+}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/DispatchableSubPlan.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/DispatchableSubPlan.java
new file mode 100644
index 0000000000..734698c122
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/DispatchableSubPlan.java
@@ -0,0 +1,74 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner;
+
+import java.util.List;
+import java.util.Set;
+import org.apache.calcite.util.Pair;
+
+
+/**
+ * The {@code DispatchableSubPlan} is the dispatchable query execution plan from the result of
+ * {@link org.apache.pinot.query.planner.logical.LogicalPlanner} and
+ * {@link org.apache.pinot.query.planner.physical.PinotDispatchPlanner}.
+ *
+ * <p>QueryPlan should contain the necessary stage boundary information and the cross exchange information
+ * for:
+ * <ul>
+ *   <li>dispatch individual stages to executor.</li>
+ *   <li>instruction for stage executor to establish connection channels to other stages.</li>
+ *   <li>instruction for encoding data blocks & transferring between stages based on partitioning scheme.</li>
+ * </ul>
+ */
+public class DispatchableSubPlan {
+  private final List<Pair<Integer, String>> _queryResultFields;
+  private final List<DispatchablePlanFragment> _queryStageList;
+  private final Set<String> _tableNames;
+
+  public DispatchableSubPlan(List<Pair<Integer, String>> fields, List<DispatchablePlanFragment> queryStageList,
+      Set<String> tableNames) {
+    _queryResultFields = fields;
+    _queryStageList = queryStageList;
+    _tableNames = tableNames;
+  }
+
+  /**
+   * Get the list of stage plan root node.
+   * @return stage plan map.
+   */
+  public List<DispatchablePlanFragment> getQueryStageList() {
+    return _queryStageList;
+  }
+
+  /**
+   * Get the query result field.
+   * @return query result field.
+   */
+  public List<Pair<Integer, String>> getQueryResultFields() {
+    return _queryResultFields;
+  }
+
+  /**
+   * Get the table names.
+   * @return table names.
+   */
+  public Set<String> getTableNames() {
+    return _tableNames;
+  }
+}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/ExplainPlanPlanVisitor.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/ExplainPlanPlanVisitor.java
index 7b5e935010..60fa47bb53 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/ExplainPlanPlanVisitor.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/ExplainPlanPlanVisitor.java
@@ -22,7 +22,6 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
-import org.apache.pinot.query.planner.physical.DispatchablePlanMetadata;
 import org.apache.pinot.query.planner.plannode.AggregateNode;
 import org.apache.pinot.query.planner.plannode.ExchangeNode;
 import org.apache.pinot.query.planner.plannode.FilterNode;
@@ -48,24 +47,30 @@ import org.apache.pinot.query.routing.QueryServerInstance;
  */
 public class ExplainPlanPlanVisitor implements PlanNodeVisitor<StringBuilder, ExplainPlanPlanVisitor.Context> {
 
-  private final QueryPlan _queryPlan;
+  private final DispatchableSubPlan _dispatchableSubPlan;
+
+  public ExplainPlanPlanVisitor(DispatchableSubPlan dispatchableSubPlan) {
+    _dispatchableSubPlan = dispatchableSubPlan;
+  }
 
   /**
    * Explains the query plan.
    *
-   * @see QueryPlan#explain()
-   * @param queryPlan the queryPlan to explain
+   * @see DispatchableSubPlan#explain()
+   * @param dispatchableSubPlan the queryPlan to explain
    * @return a String representation of the query plan tree
    */
-  public static String explain(QueryPlan queryPlan) {
-    if (queryPlan.getQueryStageMap().isEmpty()) {
+  public static String explain(DispatchableSubPlan dispatchableSubPlan) {
+    if (dispatchableSubPlan.getQueryStageList().isEmpty()) {
       return "EMPTY";
     }
 
     // the root of a query plan always only has a single node
-    QueryServerInstance rootServer = queryPlan.getDispatchablePlanMetadataMap().get(0).getServerInstanceToWorkerIdMap()
-        .keySet().iterator().next();
-    return explainFrom(queryPlan, queryPlan.getQueryStageMap().get(0), rootServer);
+    QueryServerInstance rootServer =
+        dispatchableSubPlan.getQueryStageList().get(0).getServerInstanceToWorkerIdMap()
+            .keySet().iterator().next();
+    return explainFrom(dispatchableSubPlan,
+        dispatchableSubPlan.getQueryStageList().get(0).getPlanFragment().getFragmentRoot(), rootServer);
   }
 
   /**
@@ -74,23 +79,20 @@ public class ExplainPlanPlanVisitor implements PlanNodeVisitor<StringBuilder, Ex
    * at a given point in time (for example, printing the tree that will be executed on a
    * local node right before it is executed).
    *
-   * @param queryPlan the entire query plan, including non-executed portions
+   * @param dispatchableSubPlan the entire query plan, including non-executed portions
    * @param node the node to begin traversal
    * @param rootServer the server instance that is executing this plan (should execute {@code node})
    *
    * @return a query plan associated with
    */
-  public static String explainFrom(QueryPlan queryPlan, PlanNode node, QueryServerInstance rootServer) {
-    final ExplainPlanPlanVisitor visitor = new ExplainPlanPlanVisitor(queryPlan);
+  public static String explainFrom(DispatchableSubPlan dispatchableSubPlan, PlanNode node,
+      QueryServerInstance rootServer) {
+    final ExplainPlanPlanVisitor visitor = new ExplainPlanPlanVisitor(dispatchableSubPlan);
     return node
         .visit(visitor, new Context(rootServer, 0, "", "", new StringBuilder()))
         .toString();
   }
 
-  private ExplainPlanPlanVisitor(QueryPlan queryPlan) {
-    _queryPlan = queryPlan;
-  }
-
   private StringBuilder appendInfo(PlanNode node, Context context) {
     int planFragmentId = node.getPlanFragmentId();
     context._builder
@@ -154,10 +156,11 @@ public class ExplainPlanPlanVisitor implements PlanNodeVisitor<StringBuilder, Ex
 
     MailboxSendNode sender = (MailboxSendNode) node.getSender();
     int senderStageId = node.getSenderStageId();
-    DispatchablePlanMetadata metadata = _queryPlan.getDispatchablePlanMetadataMap().get(senderStageId);
-    Map<Integer, Map<String, List<String>>> segments = metadata.getWorkerIdToSegmentsMap();
+    DispatchablePlanFragment dispatchablePlanFragment = _dispatchableSubPlan.getQueryStageList().get(senderStageId);
+    Map<Integer, Map<String, List<String>>> segments = dispatchablePlanFragment.getWorkerIdToSegmentsMap();
 
-    Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap = metadata.getServerInstanceToWorkerIdMap();
+    Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap =
+        dispatchablePlanFragment.getServerInstanceToWorkerIdMap();
     Iterator<QueryServerInstance> iterator = serverInstanceToWorkerIdMap.keySet().iterator();
     while (iterator.hasNext()) {
       QueryServerInstance queryServerInstance = iterator.next();
@@ -191,8 +194,9 @@ public class ExplainPlanPlanVisitor implements PlanNodeVisitor<StringBuilder, Ex
     appendInfo(node, context);
 
     int receiverStageId = node.getReceiverStageId();
-    Map<QueryServerInstance, List<Integer>> servers = _queryPlan.getDispatchablePlanMetadataMap().get(receiverStageId)
-        .getServerInstanceToWorkerIdMap();
+    Map<QueryServerInstance, List<Integer>> servers =
+        _dispatchableSubPlan.getQueryStageList().get(receiverStageId)
+            .getServerInstanceToWorkerIdMap();
     context._builder.append("->");
     String receivers = servers.entrySet().stream()
         .map(ExplainPlanPlanVisitor::stringifyQueryServerInstanceToWorkerIdsEntry)
@@ -215,7 +219,7 @@ public class ExplainPlanPlanVisitor implements PlanNodeVisitor<StringBuilder, Ex
   public StringBuilder visitTableScan(TableScanNode node, Context context) {
     return appendInfo(node, context)
         .append(' ')
-        .append(_queryPlan.getDispatchablePlanMetadataMap()
+        .append(_dispatchableSubPlan.getQueryStageList()
             .get(node.getPlanFragmentId())
             .getWorkerIdToSegmentsMap()
             .get(context._host))
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlanFragment.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlanFragment.java
new file mode 100644
index 0000000000..b8dd4e1be6
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlanFragment.java
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner;
+
+import java.util.List;
+import org.apache.pinot.query.planner.plannode.PlanNode;
+
+
+/**
+ * The {@code PlanFragment} is the logical sub query plan that should be scheduled together from the result of
+ * {@link org.apache.pinot.query.planner.logical.PinotQueryFragmenter}.
+ *
+ */
+public class PlanFragment {
+
+  private final int _fragmentId;
+  private final PlanNode _fragmentRoot;
+  private final PlanFragmentMetadata _fragmentMetadata;
+
+  private final List<PlanFragment> _children;
+
+  public PlanFragment(int fragmentId, PlanNode fragmentRoot, PlanFragmentMetadata fragmentMetadata,
+      List<PlanFragment> children) {
+    _fragmentId = fragmentId;
+    _fragmentRoot = fragmentRoot;
+    _fragmentMetadata = fragmentMetadata;
+    _children = children;
+  }
+
+  public int getFragmentId() {
+    return _fragmentId;
+  }
+
+  public PlanNode getFragmentRoot() {
+    return _fragmentRoot;
+  }
+
+  public PlanFragmentMetadata getFragmentMetadata() {
+    return _fragmentMetadata;
+  }
+
+  public List<PlanFragment> getChildren() {
+    return _children;
+  }
+}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlanFragmentMetadata.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlanFragmentMetadata.java
new file mode 100644
index 0000000000..0916a8b34c
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlanFragmentMetadata.java
@@ -0,0 +1,28 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner;
+
+/**
+ * Metadata for a plan fragment. This class won't leave the query planner/broker side.
+ */
+public class PlanFragmentMetadata {
+
+  public PlanFragmentMetadata() {
+  }
+}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/QueryPlan.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/QueryPlan.java
index 25f410cf84..2b762a26d2 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/QueryPlan.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/QueryPlan.java
@@ -18,129 +18,34 @@
  */
 package org.apache.pinot.query.planner;
 
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-import org.apache.calcite.util.Pair;
-import org.apache.pinot.query.planner.physical.DispatchablePlanMetadata;
 import org.apache.pinot.query.planner.plannode.PlanNode;
-import org.apache.pinot.query.routing.MailboxMetadata;
-import org.apache.pinot.query.routing.PlanFragmentMetadata;
-import org.apache.pinot.query.routing.QueryServerInstance;
-import org.apache.pinot.query.routing.VirtualServerAddress;
-import org.apache.pinot.query.routing.WorkerMetadata;
 
 
 /**
- * The {@code QueryPlan} is the dispatchable query execution plan from the result of
- * {@link org.apache.pinot.query.planner.logical.StagePlanner}.
+ * The {@code QueryPlan} is the logical query plan from the result of
+ * {@link org.apache.pinot.query.planner.logical.PinotLogicalQueryPlanner}.
  *
- * <p>QueryPlan should contain the necessary stage boundary information and the cross exchange information
- * for:
- * <ul>
- *   <li>dispatch individual stages to executor.</li>
- *   <li>instruction for stage executor to establish connection channels to other stages.</li>
- *   <li>instruction for encoding data blocks & transferring between stages based on partitioning scheme.</li>
- * </ul>
  */
 public class QueryPlan {
-  private final List<Pair<Integer, String>> _queryResultFields;
-  private final Map<Integer, PlanNode> _queryStageMap;
-  private final List<PlanFragmentMetadata> _planFragmentMetadataList;
-  private final Map<Integer, DispatchablePlanMetadata> _dispatchablePlanMetadataMap;
+  private final PlanNode _planRoot;
+  private final QueryPlanMetadata _queryPlanMetadata;
 
-  public QueryPlan(List<Pair<Integer, String>> fields, Map<Integer, PlanNode> queryStageMap,
-      Map<Integer, DispatchablePlanMetadata> dispatchablePlanMetadataMap) {
-    _queryResultFields = fields;
-    _queryStageMap = queryStageMap;
-    _dispatchablePlanMetadataMap = dispatchablePlanMetadataMap;
-    _planFragmentMetadataList = constructStageMetadataList(_dispatchablePlanMetadataMap);
+  public QueryPlan(PlanNode queryPlanRoot, QueryPlanMetadata queryPlanMetadata) {
+    _planRoot = queryPlanRoot;
+    _queryPlanMetadata = queryPlanMetadata;
   }
 
   /**
-   * Get the map between stageID and the stage plan root node.
-   * @return stage plan map.
+   * Get the root node of the query plan.
    */
-  public Map<Integer, PlanNode> getQueryStageMap() {
-    return _queryStageMap;
+  public PlanNode getPlanRoot() {
+    return _planRoot;
   }
 
   /**
-   * Get the stage metadata information based on planFragmentId.
-   * @return stage metadata info.
+   * Get the metadata of the query plan.
    */
-  public PlanFragmentMetadata getStageMetadata(int planFragmentId) {
-    return _planFragmentMetadataList.get(planFragmentId);
-  }
-
-  /**
-   * Get the dispatch metadata information.
-   * @return dispatch metadata info.
-   */
-  public Map<Integer, DispatchablePlanMetadata> getDispatchablePlanMetadataMap() {
-    return _dispatchablePlanMetadataMap;
-  }
-
-  /**
-   * Get the query result field.
-   * @return query result field.
-   */
-  public List<Pair<Integer, String>> getQueryResultFields() {
-    return _queryResultFields;
-  }
-
-  /**
-   * Explains the {@code QueryPlan}
-   *
-   * @return a human-readable tree explaining the query plan
-   * @see ExplainPlanPlanVisitor#explain(QueryPlan)
-   * @apiNote this is <b>NOT</b> identical to the SQL {@code EXPLAIN PLAN FOR} functionality
-   *          and is instead intended to be used by developers debugging during feature
-   *          development
-   */
-  public String explain() {
-    return ExplainPlanPlanVisitor.explain(this);
-  }
-
-  /**
-   * Convert the {@link DispatchablePlanMetadata} into dispatchable info for each stage/worker.
-   */
-  private static List<PlanFragmentMetadata> constructStageMetadataList(
-      Map<Integer, DispatchablePlanMetadata> dispatchablePlanMetadataMap) {
-    PlanFragmentMetadata[] planFragmentMetadataList = new PlanFragmentMetadata[dispatchablePlanMetadataMap.size()];
-    for (Map.Entry<Integer, DispatchablePlanMetadata> dispatchableEntry : dispatchablePlanMetadataMap.entrySet()) {
-      DispatchablePlanMetadata dispatchablePlanMetadata = dispatchableEntry.getValue();
-
-      // construct each worker metadata
-      WorkerMetadata[] workerMetadataList = new WorkerMetadata[dispatchablePlanMetadata.getTotalWorkerCount()];
-      for (Map.Entry<QueryServerInstance, List<Integer>> queryServerEntry
-          : dispatchablePlanMetadata.getServerInstanceToWorkerIdMap().entrySet()) {
-        for (int workerId : queryServerEntry.getValue()) {
-          VirtualServerAddress virtualServerAddress = new VirtualServerAddress(queryServerEntry.getKey(), workerId);
-          WorkerMetadata.Builder builder = new WorkerMetadata.Builder();
-          builder.setVirtualServerAddress(virtualServerAddress);
-          Map<Integer, MailboxMetadata> planFragmentToMailboxMetadata =
-              dispatchablePlanMetadata.getWorkerIdToMailBoxIdsMap().get(workerId);
-          builder.putAllMailBoxInfosMap(planFragmentToMailboxMetadata);
-          if (dispatchablePlanMetadata.getScannedTables().size() == 1) {
-            builder.addTableSegmentsMap(dispatchablePlanMetadata.getWorkerIdToSegmentsMap().get(workerId));
-          }
-          workerMetadataList[workerId] = builder.build();
-        }
-      }
-
-      // construct the stageMetadata
-      int planFragmentId = dispatchableEntry.getKey();
-      PlanFragmentMetadata.Builder builder = new PlanFragmentMetadata.Builder();
-      builder.setWorkerMetadataList(Arrays.asList(workerMetadataList));
-      if (dispatchablePlanMetadata.getScannedTables().size() == 1) {
-        builder.addTableName(dispatchablePlanMetadata.getScannedTables().get(0));
-      }
-      if (dispatchablePlanMetadata.getTimeBoundaryInfo() != null) {
-        builder.addTimeBoundaryInfo(dispatchablePlanMetadata.getTimeBoundaryInfo());
-      }
-      planFragmentMetadataList[planFragmentId] = builder.build();
-    }
-    return Arrays.asList(planFragmentMetadataList);
+  public QueryPlanMetadata getPlanMetadata() {
+    return _queryPlanMetadata;
   }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/QueryPlanMetadata.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/QueryPlanMetadata.java
new file mode 100644
index 0000000000..9bd6267f22
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/QueryPlanMetadata.java
@@ -0,0 +1,63 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner;
+
+import com.google.common.collect.ImmutableList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.calcite.util.Pair;
+
+
+/**
+ * QueryPlanMetadata contains the metadata of the {@code QueryPlan}.
+ * It contains the table names and the fields of the query result.
+ */
+public class QueryPlanMetadata {
+  private final Set<String> _tableNames;
+  private final List<Pair<Integer, String>> _fields;
+  private final Map<String, String> _customProperties;
+
+  public QueryPlanMetadata(Set<String> tableNames, ImmutableList<Pair<Integer, String>> fields) {
+    _tableNames = tableNames;
+    _fields = fields;
+    _customProperties = new HashMap<>();
+  }
+
+  public Map<String, String> getCustomProperties() {
+    return _customProperties;
+  }
+
+  /**
+   * Get the table names.
+   * @return table names.
+   */
+  public Set<String> getTableNames() {
+    return _tableNames;
+  }
+
+  /**
+   * Get the query result field.
+   * @return query result field.
+   */
+  public List<Pair<Integer, String>> getFields() {
+    return _fields;
+  }
+}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/SubPlan.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/SubPlan.java
new file mode 100644
index 0000000000..f0af6a428d
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/SubPlan.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner;
+
+import java.util.List;
+
+
+/**
+ * The {@code SubPlan} is the logical sub query plan that should be scheduled together from the result of
+ * {@link org.apache.pinot.query.planner.logical.SubPlanFragmenter}.
+ *
+ */
+public class SubPlan {
+  /**
+   * The root node of the sub query plan.
+   */
+  private final PlanFragment _subPlanRoot;
+  /**
+   * The metadata of the sub query plan.
+   */
+  private final SubPlanMetadata _subPlanMetadata;
+  /**
+   * The list of children sub query plans.
+   */
+  private final List<SubPlan> _children;
+
+  public SubPlan(PlanFragment subPlanRoot, SubPlanMetadata subPlanMetadata, List<SubPlan> children) {
+    _subPlanRoot = subPlanRoot;
+    _subPlanMetadata = subPlanMetadata;
+    _children = children;
+  }
+
+  public PlanFragment getSubPlanRoot() {
+    return _subPlanRoot;
+  }
+
+  public SubPlanMetadata getSubPlanMetadata() {
+    return _subPlanMetadata;
+  }
+
+  public List<SubPlan> getChildren() {
+    return _children;
+  }
+}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/SubPlanMetadata.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/SubPlanMetadata.java
new file mode 100644
index 0000000000..ca48949275
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/SubPlanMetadata.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner;
+
+import java.util.List;
+import java.util.Set;
+import org.apache.calcite.util.Pair;
+
+
+/**
+ * Metadata for a subplan. This class won't leave the query planner/broker side.
+ */
+public class SubPlanMetadata {
+
+  /**
+   * The set of tables that are scanned in this subplan.
+   */
+  private final Set<String> _tableNames;
+
+  /**
+   * The list of fields that are surfaced by this subplan. Only valid for SubPlan Id 0.
+   */
+  private List<Pair<Integer, String>> _fields;
+
+  public SubPlanMetadata(Set<String> tableNames, List<Pair<Integer, String>> fields) {
+    _tableNames = tableNames;
+    _fields = fields;
+  }
+
+  public List<Pair<Integer, String>> getFields() {
+    return _fields;
+  }
+
+  public Set<String> getTableNames() {
+    return _tableNames;
+  }
+}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/LiteralValueNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/LiteralValueNode.java
new file mode 100644
index 0000000000..4239108c4c
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/LiteralValueNode.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.logical;
+
+import org.apache.pinot.common.datatable.DataTable;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.planner.plannode.AbstractPlanNode;
+import org.apache.pinot.query.planner.plannode.PlanNodeVisitor;
+
+
+/**
+ * TODO: A placeholder class for literal values produced after SubPlan execution.
+ * Expected to change drastically in the future.
+ */
+public class LiteralValueNode extends AbstractPlanNode {
+
+  private DataTable _dataTable;
+
+  public LiteralValueNode(DataSchema dataSchema) {
+    super(-1, dataSchema);
+  }
+
+  public void setDataTable(DataTable dataTable) {
+    _dataTable = dataTable;
+  }
+
+  public DataTable getDataTable() {
+    return _dataTable;
+  }
+
+  @Override
+  public String explain() {
+    return "LITERAL_VALUE";
+  }
+
+  @Override
+  public <T, C> T visit(PlanNodeVisitor<T, C> visitor, C context) {
+    throw new UnsupportedOperationException("LiteralValueNode visit is not supported yet");
+  }
+}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/PinotLogicalQueryPlanner.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/PinotLogicalQueryPlanner.java
new file mode 100644
index 0000000000..795276b1e5
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/PinotLogicalQueryPlanner.java
@@ -0,0 +1,142 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.logical;
+
+import com.google.common.collect.ImmutableList;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.apache.calcite.rel.RelDistribution;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.RelRoot;
+import org.apache.pinot.query.planner.PlanFragment;
+import org.apache.pinot.query.planner.PlanFragmentMetadata;
+import org.apache.pinot.query.planner.QueryPlan;
+import org.apache.pinot.query.planner.QueryPlanMetadata;
+import org.apache.pinot.query.planner.SubPlan;
+import org.apache.pinot.query.planner.SubPlanMetadata;
+import org.apache.pinot.query.planner.plannode.MailboxReceiveNode;
+import org.apache.pinot.query.planner.plannode.MailboxSendNode;
+import org.apache.pinot.query.planner.plannode.PlanNode;
+
+
+/**
+ * PinotLogicalQueryPlanner walks top-down from {@link RelRoot} and constructs a forest of trees with {@link PlanNode}.
+ *
+ * This class is non-threadsafe. Do not reuse the planner for multiple query plans.
+ */
+public class PinotLogicalQueryPlanner {
+
+  /**
+   * planQuery achieves 2 objective:
+   *   1. convert Calcite's {@link RelNode} to Pinot's {@link PlanNode} format from the {@link RelRoot} of Calcite's
+   *   LogicalPlanner result.
+   *   2. while walking Calcite's {@link RelNode} tree, populate {@link QueryPlanMetadata}.
+   *
+   * @param relRoot relational plan root.
+   * @return dispatchable plan.
+   */
+  public QueryPlan planQuery(RelRoot relRoot) {
+    RelNode relRootNode = relRoot.rel;
+    // Walk through RelNode tree and construct a PlanNode tree.
+    PlanNode globalRoot = relNodeToStageNode(relRootNode);
+    QueryPlanMetadata queryPlanMetadata =
+        new QueryPlanMetadata(RelToPlanNodeConverter.getTableNamesFromRelRoot(relRootNode), relRoot.fields);
+    return new QueryPlan(globalRoot, queryPlanMetadata);
+  }
+
+  /**
+   * Convert the Pinot plan from {@link PinotLogicalQueryPlanner#planQuery(RelRoot)} into a {@link SubPlan}.
+   *
+   * @param queryPlan relational plan root.
+   * @return dispatchable plan.
+   */
+  public SubPlan makePlan(QueryPlan queryPlan) {
+    PlanNode globalRoot = queryPlan.getPlanRoot();
+
+    // Fragment the stage tree into multiple SubPlans.
+    SubPlanFragmenter.Context subPlanContext = new SubPlanFragmenter.Context();
+    subPlanContext._subPlanIdToRootNodeMap.put(0, globalRoot);
+    subPlanContext._subPlanIdToMetadataMap.put(0,
+        new SubPlanMetadata(queryPlan.getPlanMetadata().getTableNames(), queryPlan.getPlanMetadata().getFields()));
+    globalRoot.visit(SubPlanFragmenter.INSTANCE, subPlanContext);
+
+    Map<Integer, SubPlan> subPlanMap = new HashMap<>();
+    for (Map.Entry<Integer, PlanNode> subPlanEntry : subPlanContext._subPlanIdToRootNodeMap.entrySet()) {
+      int subPlanId = subPlanEntry.getKey();
+      PlanNode subPlanRoot = subPlanEntry.getValue();
+      PlanFragmenter.Context planFragmentContext = new PlanFragmenter.Context();
+      planFragmentContext._planFragmentIdToRootNodeMap.put(1,
+          new PlanFragment(1, subPlanRoot, new PlanFragmentMetadata(), new ArrayList<>()));
+      subPlanRoot = subPlanRoot.visit(PlanFragmenter.INSTANCE, planFragmentContext);
+
+      // Sub plan root needs to send results back to the Broker ROOT, a.k.a. the client response node. The last stage
+      // only has one receiver, so the exchange type doesn't matter; setting it to RANDOM_DISTRIBUTED by default to
+      // match the distribution type used for the sender/receiver pair below.
+      PlanNode subPlanRootSenderNode =
+          new MailboxSendNode(subPlanRoot.getPlanFragmentId(), subPlanRoot.getDataSchema(),
+              0, RelDistribution.Type.RANDOM_DISTRIBUTED, null, null, false);
+      subPlanRootSenderNode.addInput(subPlanRoot);
+
+      PlanNode subPlanRootReceiverNode =
+          new MailboxReceiveNode(0, subPlanRoot.getDataSchema(), subPlanRoot.getPlanFragmentId(),
+              RelDistribution.Type.RANDOM_DISTRIBUTED, null, null, false, false, subPlanRootSenderNode);
+      subPlanRoot = subPlanRootReceiverNode;
+      PlanFragment planFragment1 = planFragmentContext._planFragmentIdToRootNodeMap.get(1);
+      planFragmentContext._planFragmentIdToRootNodeMap.put(1,
+          new PlanFragment(1, subPlanRootSenderNode, planFragment1.getFragmentMetadata(), planFragment1.getChildren()));
+      PlanFragment rootPlanFragment
+          = new PlanFragment(subPlanRoot.getPlanFragmentId(), subPlanRoot, new PlanFragmentMetadata(),
+          ImmutableList.of(planFragmentContext._planFragmentIdToRootNodeMap.get(1)));
+      planFragmentContext._planFragmentIdToRootNodeMap.put(0, rootPlanFragment);
+      for (Map.Entry<Integer, List<Integer>> planFragmentToChildrenEntry
+          : planFragmentContext._planFragmentIdToChildrenMap.entrySet()) {
+        int planFragmentId = planFragmentToChildrenEntry.getKey();
+        List<Integer> planFragmentChildren = planFragmentToChildrenEntry.getValue();
+        for (int planFragmentChild : planFragmentChildren) {
+          planFragmentContext._planFragmentIdToRootNodeMap.get(planFragmentId).getChildren()
+              .add(planFragmentContext._planFragmentIdToRootNodeMap.get(planFragmentChild));
+        }
+      }
+      SubPlan subPlan = new SubPlan(planFragmentContext._planFragmentIdToRootNodeMap.get(0),
+          subPlanContext._subPlanIdToMetadataMap.get(0), new ArrayList<>());
+      subPlanMap.put(subPlanId, subPlan);
+    }
+    for (Map.Entry<Integer, List<Integer>> subPlanToChildrenEntry : subPlanContext._subPlanIdToChildrenMap.entrySet()) {
+      int subPlanId = subPlanToChildrenEntry.getKey();
+      List<Integer> subPlanChildren = subPlanToChildrenEntry.getValue();
+      for (int subPlanChild : subPlanChildren) {
+        subPlanMap.get(subPlanId).getChildren().add(subPlanMap.get(subPlanChild));
+      }
+    }
+    return subPlanMap.get(0);
+  }
+
+  // non-threadsafe
+  // TODO: add dataSchema (extracted from RelNode schema) to the PlanNode.
+  private PlanNode relNodeToStageNode(RelNode node) {
+    PlanNode planNode = RelToPlanNodeConverter.toStageNode(node, -1);
+    List<RelNode> inputs = node.getInputs();
+    for (RelNode input : inputs) {
+      planNode.addInput(relNodeToStageNode(input));
+    }
+    return planNode;
+  }
+}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StageFragmenter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/PlanFragmenter.java
similarity index 60%
copy from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StageFragmenter.java
copy to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/PlanFragmenter.java
index d04443f04f..1c4c90f958 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StageFragmenter.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/PlanFragmenter.java
@@ -18,8 +18,13 @@
  */
 package org.apache.pinot.query.planner.logical;
 
+import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import org.apache.calcite.rel.RelDistribution;
+import org.apache.pinot.query.planner.PlanFragment;
+import org.apache.pinot.query.planner.PlanFragmentMetadata;
 import org.apache.pinot.query.planner.partitioning.FieldSelectionKeySelector;
 import org.apache.pinot.query.planner.partitioning.KeySelector;
 import org.apache.pinot.query.planner.plannode.AggregateNode;
@@ -38,14 +43,25 @@ import org.apache.pinot.query.planner.plannode.ValueNode;
 import org.apache.pinot.query.planner.plannode.WindowNode;
 
 
-public class StageFragmenter implements PlanNodeVisitor<PlanNode, StageFragmenter.Context> {
-  public static final StageFragmenter INSTANCE = new StageFragmenter();
+/**
+ * PlanFragmenter is an implementation of {@link PlanNodeVisitor} to fragment a
+ * {@link org.apache.pinot.query.planner.SubPlan} into multiple {@link PlanFragment}.
+ *
+ * The fragmenting process is as follows:
+ * 1. Traverse the plan tree in a depth-first manner;
+ * 2. For each node, if it is an ExchangeNode eligible for PlanFragment splitting, split it into a
+ * {@link MailboxReceiveNode} and {@link MailboxSendNode} pair;
+ * 3. Assign current PlanFragment Id to {@link MailboxReceiveNode};
+ * 4. Increment current PlanFragment Id by one and assign it to the {@link MailboxSendNode}.
+ */
+public class PlanFragmenter implements PlanNodeVisitor<PlanNode, PlanFragmenter.Context> {
+  public static final PlanFragmenter INSTANCE = new PlanFragmenter();
 
   private PlanNode process(PlanNode node, Context context) {
-    node.setPlanFragmentId(context._currentStageId);
+    node.setPlanFragmentId(context._currentPlanFragmentId);
     List<PlanNode> inputs = node.getInputs();
     for (int i = 0; i < inputs.size(); i++) {
-      context._previousStageId = node.getPlanFragmentId();
+      context._previousPlanFragmentId = node.getPlanFragmentId();
       inputs.set(i, inputs.get(i).visit(this, context));
     }
     return node;
@@ -68,12 +84,12 @@ public class StageFragmenter implements PlanNodeVisitor<PlanNode, StageFragmente
 
   @Override
   public PlanNode visitMailboxReceive(MailboxReceiveNode node, Context context) {
-    throw new UnsupportedOperationException("MailboxReceiveNode should not be visited by StageFragmenter");
+    throw new UnsupportedOperationException("MailboxReceiveNode should not be visited by PlanNodeFragmenter");
   }
 
   @Override
   public PlanNode visitMailboxSend(MailboxSendNode node, Context context) {
-    throw new UnsupportedOperationException("MailboxSendNode should not be visited by StageFragmenter");
+    throw new UnsupportedOperationException("MailboxSendNode should not be visited by PlanNodeFragmenter");
   }
 
   @Override
@@ -108,10 +124,12 @@ public class StageFragmenter implements PlanNodeVisitor<PlanNode, StageFragmente
 
   @Override
   public PlanNode visitExchange(ExchangeNode node, Context context) {
-    int nodeStageId = context._previousStageId;
-
-    context._currentStageId++;
-    PlanNode nextStageRoot = node.getInputs().get(0).visit(this, context);
+    if (!isPlanFragmentSplitter(node)) {
+      return process(node, context);
+    }
+    int currentPlanFragmentId = context._previousPlanFragmentId;
+    int nextPlanFragmentId = ++context._currentPlanFragmentId;
+    PlanNode nextPlanFragmentRoot = node.getInputs().get(0).visit(this, context);
 
     List<Integer> distributionKeys = node.getDistributionKeys();
     RelDistribution.Type exchangeType = node.getDistributionType();
@@ -122,20 +140,36 @@ public class StageFragmenter implements PlanNodeVisitor<PlanNode, StageFragmente
     KeySelector<Object[], Object[]> keySelector = exchangeType == RelDistribution.Type.HASH_DISTRIBUTED
         ? new FieldSelectionKeySelector(distributionKeys) : null;
 
-    PlanNode mailboxSender = new MailboxSendNode(nextStageRoot.getPlanFragmentId(), nextStageRoot.getDataSchema(),
-        nodeStageId, exchangeType, keySelector, node.getCollations(), node.isSortOnSender());
-    PlanNode mailboxReceiver = new MailboxReceiveNode(nodeStageId, nextStageRoot.getDataSchema(),
-        nextStageRoot.getPlanFragmentId(), exchangeType, keySelector,
+    PlanNode mailboxSender =
+        new MailboxSendNode(nextPlanFragmentId, nextPlanFragmentRoot.getDataSchema(),
+            currentPlanFragmentId, exchangeType, keySelector, node.getCollations(), node.isSortOnSender());
+    PlanNode mailboxReceiver = new MailboxReceiveNode(currentPlanFragmentId, nextPlanFragmentRoot.getDataSchema(),
+        nextPlanFragmentId, exchangeType, keySelector,
         node.getCollations(), node.isSortOnSender(), node.isSortOnReceiver(), mailboxSender);
-    mailboxSender.addInput(nextStageRoot);
+    mailboxSender.addInput(nextPlanFragmentRoot);
+
+    context._planFragmentIdToRootNodeMap.put(nextPlanFragmentId,
+        new PlanFragment(nextPlanFragmentId, mailboxSender, new PlanFragmentMetadata(), new ArrayList<>()));
+    if (!context._planFragmentIdToChildrenMap.containsKey(currentPlanFragmentId)) {
+      context._planFragmentIdToChildrenMap.put(currentPlanFragmentId, new ArrayList<>());
+    }
+    context._planFragmentIdToChildrenMap.get(currentPlanFragmentId).add(nextPlanFragmentId);
 
     return mailboxReceiver;
   }
 
+  private boolean isPlanFragmentSplitter(PlanNode node) {
+    // TODO: always return true for now, we will add more logic here later.
+    return true;
+  }
+
   public static class Context {
 
-    // Stage ID starts with 1, 0 will be reserved for ROOT PlanFragment.
-    Integer _currentStageId = 1;
-    Integer _previousStageId = 1;
+    // PlanFragment ID starts with 1, 0 will be reserved for ROOT PlanFragment.
+    Integer _currentPlanFragmentId = 1;
+    Integer _previousPlanFragmentId = 1;
+    Map<Integer, PlanFragment> _planFragmentIdToRootNodeMap = new HashMap<>();
+
+    Map<Integer, List<Integer>> _planFragmentIdToChildrenMap = new HashMap<>();
   }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
similarity index 91%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
index 3d353b2514..efe35ce05b 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java
@@ -18,8 +18,11 @@
  */
 package org.apache.pinot.query.planner.logical;
 
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 import java.util.stream.Collectors;
+import org.apache.calcite.plan.RelOptUtil;
 import org.apache.calcite.rel.RelCollation;
 import org.apache.calcite.rel.RelFieldCollation;
 import org.apache.calcite.rel.RelNode;
@@ -59,11 +62,11 @@ import org.apache.pinot.spi.data.FieldSpec;
 
 
 /**
- * The {@code StageNodeConverter} converts a logical {@link RelNode} to a {@link PlanNode}.
+ * The {@link RelToPlanNodeConverter} converts a logical {@link RelNode} to a {@link PlanNode}.
  */
-public final class RelToStageConverter {
+public final class RelToPlanNodeConverter {
 
-  private RelToStageConverter() {
+  private RelToPlanNodeConverter() {
     // do not instantiate.
   }
 
@@ -114,8 +117,11 @@ public final class RelToStageConverter {
       }
     }
     List<RelFieldCollation> fieldCollations = (collation == null) ? null : collation.getFieldCollations();
-    return new ExchangeNode(currentStageId, toDataSchema(node.getRowType()), node.getDistribution(), fieldCollations,
-        isSortOnSender, isSortOnReceiver);
+
+    // Compute all the tables involved under this exchange node
+    Set<String> tableNames = getTableNamesFromRelRoot(node);
+    return new ExchangeNode(currentStageId, toDataSchema(node.getRowType()), tableNames, node.getDistribution(),
+        fieldCollations, isSortOnSender, isSortOnReceiver);
   }
 
   private static PlanNode convertLogicalSetOp(SetOp node, int currentStageId) {
@@ -265,4 +271,16 @@ public final class RelToStageConverter {
       }
     }
   }
+
+  public static Set<String> getTableNamesFromRelRoot(RelNode relRoot) {
+    Set<String> tableNames = new HashSet<>();
+    List<String> qualifiedTableNames = RelOptUtil.findAllTableQualifiedNames(relRoot);
+    for (String qualifiedTableName : qualifiedTableNames) {
+      // Calcite encloses table and schema names in square brackets to properly quote and delimit them in SQL
+      // statements, particularly to handle cases when they contain special characters or reserved keywords.
+      String tableName = qualifiedTableName.replaceAll("^\\[(.*)\\]$", "$1");
+      tableNames.add(tableName);
+    }
+    return tableNames;
+  }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
index abe1738fc6..9b879ab779 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
@@ -48,7 +48,7 @@ public interface RexExpression {
       return new RexExpression.InputRef(((RexInputRef) rexNode).getIndex());
     } else if (rexNode instanceof RexLiteral) {
       RexLiteral rexLiteral = ((RexLiteral) rexNode);
-      FieldSpec.DataType dataType = RelToStageConverter.convertToFieldSpecDataType(rexLiteral.getType());
+      FieldSpec.DataType dataType = RelToPlanNodeConverter.convertToFieldSpecDataType(rexLiteral.getType());
       return new RexExpression.Literal(dataType, toRexValue(dataType, rexLiteral.getValue()));
     } else if (rexNode instanceof RexCall) {
       RexCall rexCall = (RexCall) rexNode;
@@ -70,7 +70,7 @@ public interface RexExpression {
         List<RexExpression> operands =
             rexCall.getOperands().stream().map(RexExpression::toRexExpression).collect(Collectors.toList());
         return new RexExpression.FunctionCall(rexCall.getKind(),
-            RelToStageConverter.convertToFieldSpecDataType(rexCall.getType()),
+            RelToPlanNodeConverter.convertToFieldSpecDataType(rexCall.getType()),
             rexCall.getOperator().getName(), operands);
     }
   }
@@ -78,7 +78,7 @@ public interface RexExpression {
   static RexExpression toRexExpression(AggregateCall aggCall) {
     List<RexExpression> operands = aggCall.getArgList().stream().map(InputRef::new).collect(Collectors.toList());
     return new RexExpression.FunctionCall(aggCall.getAggregation().getKind(),
-        RelToStageConverter.convertToFieldSpecDataType(aggCall.getType()), aggCall.getAggregation().getName(),
+        RelToPlanNodeConverter.convertToFieldSpecDataType(aggCall.getType()), aggCall.getAggregation().getName(),
         operands);
   }
 
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java
index c5a47fe789..242291c280 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java
@@ -44,7 +44,7 @@ public class RexExpressionUtils {
     List<RexExpression> operands =
         rexCall.getOperands().stream().map(RexExpression::toRexExpression).collect(Collectors.toList());
     return new RexExpression.FunctionCall(rexCall.getKind(),
-        RelToStageConverter.convertToFieldSpecDataType(rexCall.getType()),
+        RelToPlanNodeConverter.convertToFieldSpecDataType(rexCall.getType()),
         "caseWhen", operands);
   }
 
@@ -56,8 +56,9 @@ public class RexExpressionUtils {
     Preconditions.checkState(operands.size() == 1, "CAST takes exactly 2 arguments");
     RelDataType castType = rexCall.getType();
     operands.add(new RexExpression.Literal(FieldSpec.DataType.STRING,
-        RelToStageConverter.convertToFieldSpecDataType(castType).name()));
-    return new RexExpression.FunctionCall(rexCall.getKind(), RelToStageConverter.convertToFieldSpecDataType(castType),
+        RelToPlanNodeConverter.convertToFieldSpecDataType(castType).name()));
+    return new RexExpression.FunctionCall(rexCall.getKind(),
+        RelToPlanNodeConverter.convertToFieldSpecDataType(castType),
         "CAST", operands);
   }
 
@@ -66,7 +67,7 @@ public class RexExpressionUtils {
     List<RexNode> operands = rexCall.getOperands();
     RexInputRef rexInputRef = (RexInputRef) operands.get(0);
     RexLiteral rexLiteral = (RexLiteral) operands.get(1);
-    FieldSpec.DataType dataType = RelToStageConverter.convertToFieldSpecDataType(rexLiteral.getType());
+    FieldSpec.DataType dataType = RelToPlanNodeConverter.convertToFieldSpecDataType(rexLiteral.getType());
     Sarg sarg = rexLiteral.getValueAs(Sarg.class);
     if (sarg.isPoints()) {
       return new RexExpression.FunctionCall(SqlKind.IN, dataType, SqlKind.IN.name(), toFunctionOperands(rexInputRef,
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
deleted file mode 100644
index 5ca7de8543..0000000000
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
+++ /dev/null
@@ -1,112 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.pinot.query.planner.logical;
-
-import java.util.List;
-import java.util.Set;
-import org.apache.calcite.rel.RelDistribution;
-import org.apache.calcite.rel.RelNode;
-import org.apache.calcite.rel.RelRoot;
-import org.apache.pinot.common.config.provider.TableCache;
-import org.apache.pinot.query.context.PlannerContext;
-import org.apache.pinot.query.planner.QueryPlan;
-import org.apache.pinot.query.planner.physical.DispatchablePlanContext;
-import org.apache.pinot.query.planner.physical.DispatchablePlanVisitor;
-import org.apache.pinot.query.planner.physical.colocated.GreedyShuffleRewriteVisitor;
-import org.apache.pinot.query.planner.plannode.MailboxReceiveNode;
-import org.apache.pinot.query.planner.plannode.MailboxSendNode;
-import org.apache.pinot.query.planner.plannode.PlanNode;
-import org.apache.pinot.query.routing.WorkerManager;
-
-
-/**
- * QueryPlanMaker walks top-down from {@link RelRoot} and construct a forest of trees with {@link PlanNode}.
- *
- * This class is non-threadsafe. Do not reuse the stage planner for multiple query plans.
- */
-public class StagePlanner {
-  private final PlannerContext _plannerContext;   // DO NOT REMOVE.
-  private final WorkerManager _workerManager;
-  private final TableCache _tableCache;
-  private long _requestId;
-
-  public StagePlanner(PlannerContext plannerContext, WorkerManager workerManager, long requestId,
-      TableCache tableCache) {
-    _plannerContext = plannerContext;
-    _workerManager = workerManager;
-    _requestId = requestId;
-    _tableCache = tableCache;
-  }
-
-  /**
-   * Construct the dispatchable plan from relational logical plan.
-   *
-   * @param relRoot relational plan root.
-   * @return dispatchable plan.
-   */
-  public QueryPlan makePlan(RelRoot relRoot, Set<String> tableNames) {
-    RelNode relRootNode = relRoot.rel;
-
-    // Walk through RelNode tree and construct a StageNode tree.
-    PlanNode globalStageRoot = relNodeToStageNode(relRootNode);
-
-    // Fragment the stage tree into multiple stages.
-    globalStageRoot = globalStageRoot.visit(StageFragmenter.INSTANCE, new StageFragmenter.Context());
-
-    // global root needs to send results back to the ROOT, a.k.a. the client response node. the last stage only has one
-    // receiver so doesn't matter what the exchange type is. setting it to SINGLETON by default.
-    PlanNode globalSenderNode =
-        new MailboxSendNode(globalStageRoot.getPlanFragmentId(), globalStageRoot.getDataSchema(),
-            0, RelDistribution.Type.RANDOM_DISTRIBUTED, null, null, false);
-    globalSenderNode.addInput(globalStageRoot);
-
-    PlanNode globalReceiverNode =
-        new MailboxReceiveNode(0, globalStageRoot.getDataSchema(), globalStageRoot.getPlanFragmentId(),
-            RelDistribution.Type.RANDOM_DISTRIBUTED, null, null, false, false, globalSenderNode);
-
-    // perform physical plan conversion and assign workers to each stage.
-    DispatchablePlanContext dispatchablePlanContext = new DispatchablePlanContext(_workerManager, _requestId,
-        _plannerContext, relRoot.fields, tableNames);
-    QueryPlan queryPlan = DispatchablePlanVisitor.INSTANCE.constructDispatchablePlan(globalReceiverNode,
-        dispatchablePlanContext);
-
-    // Run physical optimizations
-    runPhysicalOptimizers(queryPlan);
-
-    return queryPlan;
-  }
-
-  // non-threadsafe
-  // TODO: add dataSchema (extracted from RelNode schema) to the StageNode.
-  private PlanNode relNodeToStageNode(RelNode node) {
-    PlanNode planNode = RelToStageConverter.toStageNode(node, -1);
-    List<RelNode> inputs = node.getInputs();
-    for (RelNode input : inputs) {
-      planNode.addInput(relNodeToStageNode(input));
-    }
-    return planNode;
-  }
-
-  // TODO: Switch to Worker SPI to avoid multiple-places where workers are assigned.
-  private void runPhysicalOptimizers(QueryPlan queryPlan) {
-    if (_plannerContext.getOptions().getOrDefault("useColocatedJoin", "false").equals("true")) {
-      GreedyShuffleRewriteVisitor.optimizeShuffles(queryPlan, _tableCache);
-    }
-  }
-}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StageFragmenter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/SubPlanFragmenter.java
similarity index 62%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StageFragmenter.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/SubPlanFragmenter.java
index d04443f04f..adb87d1ad8 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StageFragmenter.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/SubPlanFragmenter.java
@@ -18,10 +18,12 @@
  */
 package org.apache.pinot.query.planner.logical;
 
+import com.google.common.collect.ImmutableList;
+import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
-import org.apache.calcite.rel.RelDistribution;
-import org.apache.pinot.query.planner.partitioning.FieldSelectionKeySelector;
-import org.apache.pinot.query.planner.partitioning.KeySelector;
+import java.util.Map;
+import org.apache.pinot.query.planner.SubPlanMetadata;
 import org.apache.pinot.query.planner.plannode.AggregateNode;
 import org.apache.pinot.query.planner.plannode.ExchangeNode;
 import org.apache.pinot.query.planner.plannode.FilterNode;
@@ -38,14 +40,23 @@ import org.apache.pinot.query.planner.plannode.ValueNode;
 import org.apache.pinot.query.planner.plannode.WindowNode;
 
 
-public class StageFragmenter implements PlanNodeVisitor<PlanNode, StageFragmenter.Context> {
-  public static final StageFragmenter INSTANCE = new StageFragmenter();
+/**
+ * SubPlanFragmenter is an implementation of {@link PlanNodeVisitor} to fragment a
+ * {@link org.apache.pinot.query.planner.QueryPlan} into multiple {@link org.apache.pinot.query.planner.SubPlan}.
+ *
+ * The fragmenting process is as follows:
+ * 1. Traverse the plan tree in a depth-first manner;
+ * 2. For each node, if it is a SubPlan-splittable ExchangeNode, switch it to a {@link LiteralValueNode};
+ * 3. Increment current SubPlan Id by one and keep traversing the tree.
+ */
+public class SubPlanFragmenter implements PlanNodeVisitor<PlanNode, SubPlanFragmenter.Context> {
+  public static final SubPlanFragmenter INSTANCE = new SubPlanFragmenter();
 
   private PlanNode process(PlanNode node, Context context) {
-    node.setPlanFragmentId(context._currentStageId);
+    node.setPlanFragmentId(context._currentSubPlanId);
     List<PlanNode> inputs = node.getInputs();
     for (int i = 0; i < inputs.size(); i++) {
-      context._previousStageId = node.getPlanFragmentId();
+      context._previousSubPlanId = node.getPlanFragmentId();
       inputs.set(i, inputs.get(i).visit(this, context));
     }
     return node;
@@ -108,34 +119,37 @@ public class StageFragmenter implements PlanNodeVisitor<PlanNode, StageFragmente
 
   @Override
   public PlanNode visitExchange(ExchangeNode node, Context context) {
-    int nodeStageId = context._previousStageId;
+    if (!isSubPlanSplitter(node)) {
+      return process(node, context);
+    }
+    int currentStageId = context._previousSubPlanId;
+    int nextSubPlanId = context._currentSubPlanId + 1;
 
-    context._currentStageId++;
+    context._currentSubPlanId = nextSubPlanId;
     PlanNode nextStageRoot = node.getInputs().get(0).visit(this, context);
+    context._subPlanIdToRootNodeMap.put(nextSubPlanId, nextStageRoot);
+    if (!context._subPlanIdToChildrenMap.containsKey(currentStageId)) {
+      context._subPlanIdToChildrenMap.put(currentStageId, new ArrayList<>());
+    }
+    context._subPlanIdToChildrenMap.get(currentStageId).add(nextSubPlanId);
+    context._subPlanIdToMetadataMap.put(nextSubPlanId, new SubPlanMetadata(node.getTableNames(), ImmutableList.of()));
+    PlanNode literalValueNode = new LiteralValueNode(nextStageRoot.getDataSchema());
+    return literalValueNode;
+  }
 
-    List<Integer> distributionKeys = node.getDistributionKeys();
-    RelDistribution.Type exchangeType = node.getDistributionType();
-
-    // make an exchange sender and receiver node pair
-    // only HASH_DISTRIBUTED requires a partition key selector; so all other types (SINGLETON and BROADCAST)
-    // of exchange will not carry a partition key selector.
-    KeySelector<Object[], Object[]> keySelector = exchangeType == RelDistribution.Type.HASH_DISTRIBUTED
-        ? new FieldSelectionKeySelector(distributionKeys) : null;
-
-    PlanNode mailboxSender = new MailboxSendNode(nextStageRoot.getPlanFragmentId(), nextStageRoot.getDataSchema(),
-        nodeStageId, exchangeType, keySelector, node.getCollations(), node.isSortOnSender());
-    PlanNode mailboxReceiver = new MailboxReceiveNode(nodeStageId, nextStageRoot.getDataSchema(),
-        nextStageRoot.getPlanFragmentId(), exchangeType, keySelector,
-        node.getCollations(), node.isSortOnSender(), node.isSortOnReceiver(), mailboxSender);
-    mailboxSender.addInput(nextStageRoot);
-
-    return mailboxReceiver;
+  private boolean isSubPlanSplitter(PlanNode node) {
+    // TODO: implement this when we introduce a new type of exchange node for sub-plan splitter
+    return false;
   }
 
   public static class Context {
+    Map<Integer, PlanNode> _subPlanIdToRootNodeMap = new HashMap<>();
+
+    Map<Integer, List<Integer>> _subPlanIdToChildrenMap = new HashMap<>();
+    Map<Integer, SubPlanMetadata> _subPlanIdToMetadataMap = new HashMap<>();
 
-    // Stage ID starts with 1, 0 will be reserved for ROOT PlanFragment.
-    Integer _currentStageId = 1;
-    Integer _previousStageId = 1;
+    // SubPlan ID starts with 0.
+    Integer _currentSubPlanId = 0;
+    Integer _previousSubPlanId = 0;
   }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanContext.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanContext.java
index aa587cf974..c0fb98d9a5 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanContext.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanContext.java
@@ -18,14 +18,21 @@
  */
 package org.apache.pinot.query.planner.physical;
 
+import com.google.common.base.Preconditions;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import org.apache.calcite.util.Pair;
 import org.apache.pinot.query.context.PlannerContext;
+import org.apache.pinot.query.planner.DispatchablePlanFragment;
+import org.apache.pinot.query.planner.PlanFragment;
 import org.apache.pinot.query.planner.plannode.PlanNode;
+import org.apache.pinot.query.routing.QueryServerInstance;
+import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.routing.WorkerManager;
+import org.apache.pinot.query.routing.WorkerMetadata;
 
 
 public class DispatchablePlanContext {
@@ -78,4 +85,56 @@ public class DispatchablePlanContext {
   public Map<Integer, PlanNode> getDispatchablePlanStageRootMap() {
     return _dispatchablePlanStageRootMap;
   }
+
+  public List<DispatchablePlanFragment> constructDispatchablePlanFragmentList(PlanFragment subPlanRoot) {
+    DispatchablePlanFragment[] dispatchablePlanFragmentArray =
+        new DispatchablePlanFragment[_dispatchablePlanStageRootMap.size()];
+    createDispatchablePlanFragmentList(dispatchablePlanFragmentArray, subPlanRoot);
+    List<DispatchablePlanFragment> dispatchablePlanFragmentList = Arrays.asList(dispatchablePlanFragmentArray);
+    for (Map.Entry<Integer, DispatchablePlanMetadata> dispatchableEntry : _dispatchablePlanMetadataMap.entrySet()) {
+      DispatchablePlanMetadata dispatchablePlanMetadata = dispatchableEntry.getValue();
+
+      // construct each worker metadata
+      WorkerMetadata[] workerMetadataList = new WorkerMetadata[dispatchablePlanMetadata.getTotalWorkerCount()];
+      for (Map.Entry<QueryServerInstance, List<Integer>> queryServerEntry
+          : dispatchablePlanMetadata.getServerInstanceToWorkerIdMap().entrySet()) {
+        for (int workerId : queryServerEntry.getValue()) {
+          VirtualServerAddress virtualServerAddress = new VirtualServerAddress(queryServerEntry.getKey(), workerId);
+          WorkerMetadata.Builder builder = new WorkerMetadata.Builder();
+          builder.setVirtualServerAddress(virtualServerAddress);
+          if (dispatchablePlanMetadata.getScannedTables().size() == 1) {
+            builder.addTableSegmentsMap(dispatchablePlanMetadata.getWorkerIdToSegmentsMap().get(workerId));
+          }
+          builder.putAllMailBoxInfosMap(dispatchablePlanMetadata.getWorkerIdToMailBoxIdsMap().get(workerId));
+          workerMetadataList[workerId] = builder.build();
+        }
+      }
+
+      // set the stageMetadata
+      int stageId = dispatchableEntry.getKey();
+      dispatchablePlanFragmentList.get(stageId).setWorkerMetadataList(Arrays.asList(workerMetadataList));
+      dispatchablePlanFragmentList.get(stageId)
+          .setWorkerIdToSegmentsMap(dispatchablePlanMetadata.getWorkerIdToSegmentsMap());
+      dispatchablePlanFragmentList.get(stageId)
+          .setServerInstanceToWorkerIdMap(dispatchablePlanMetadata.getServerInstanceToWorkerIdMap());
+      Preconditions.checkState(dispatchablePlanMetadata.getScannedTables().size() <= 1,
+          "More than one table is not supported yet");
+      if (dispatchablePlanMetadata.getScannedTables().size() == 1) {
+        dispatchablePlanFragmentList.get(stageId).setTableName(dispatchablePlanMetadata.getScannedTables().get(0));
+      }
+      if (dispatchablePlanMetadata.getTimeBoundaryInfo() != null) {
+        dispatchablePlanFragmentList.get(stageId)
+            .setTimeBoundaryInfo(dispatchablePlanMetadata.getTimeBoundaryInfo());
+      }
+    }
+    return dispatchablePlanFragmentList;
+  }
+
+  private void createDispatchablePlanFragmentList(DispatchablePlanFragment[] dispatchablePlanFragmentArray,
+      PlanFragment planFragmentRoot) {
+    dispatchablePlanFragmentArray[planFragmentRoot.getFragmentId()] = new DispatchablePlanFragment(planFragmentRoot);
+    for (PlanFragment childPlanFragment : planFragmentRoot.getChildren()) {
+      createDispatchablePlanFragmentList(dispatchablePlanFragmentArray, childPlanFragment);
+    }
+  }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanVisitor.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanVisitor.java
index 964d332127..f0e63e1760 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanVisitor.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanVisitor.java
@@ -18,7 +18,6 @@
  */
 package org.apache.pinot.query.planner.physical;
 
-import org.apache.pinot.query.planner.QueryPlan;
 import org.apache.pinot.query.planner.plannode.AggregateNode;
 import org.apache.pinot.query.planner.plannode.ExchangeNode;
 import org.apache.pinot.query.planner.plannode.FilterNode;
@@ -41,58 +40,12 @@ public class DispatchablePlanVisitor implements PlanNodeVisitor<Void, Dispatchab
   private DispatchablePlanVisitor() {
   }
 
-  /**
-   * Entry point for attaching dispatch metadata to a query plan. It walks through the plan via the global
-   * {@link PlanNode} root of the query and:
-   * <ul>
-   *   <li>break down the {@link PlanNode}s into Stages that can run on a single worker.</li>
-   *   <li>each stage is represented by a subset of {@link PlanNode}s without data exchange.</li>
-   *   <li>attach worker execution information including physical server address, worker ID to each stage.</li>
-   * </ul>
-   *
-   * @param globalReceiverNode the entrypoint of the stage plan.
-   * @param dispatchablePlanContext dispatchable plan context used to record the walk of the stage node tree.
-   */
-  public QueryPlan constructDispatchablePlan(PlanNode globalReceiverNode,
-      DispatchablePlanContext dispatchablePlanContext) {
-    // 1. start by visiting the stage root.
-    globalReceiverNode.visit(DispatchablePlanVisitor.INSTANCE, dispatchablePlanContext);
-    // 2. add a special stage for the global mailbox receive, this runs on the dispatcher.
-    dispatchablePlanContext.getDispatchablePlanStageRootMap().put(0, globalReceiverNode);
-    // 3. add worker assignment after the dispatchable plan context is fulfilled after the visit.
-    computeWorkerAssignment(globalReceiverNode, dispatchablePlanContext);
-    // 4. compute the mailbox assignment for each stage.
-    // TODO: refactor this to be a pluggable interface.
-    computeMailboxAssignment(dispatchablePlanContext);
-    // 5. convert it into query plan.
-    // TODO: refactor this to be a pluggable interface.
-    return finalizeQueryPlan(dispatchablePlanContext);
-  }
-
-  private void computeMailboxAssignment(DispatchablePlanContext dispatchablePlanContext) {
-    dispatchablePlanContext.getDispatchablePlanStageRootMap().get(0).visit(MailboxAssignmentVisitor.INSTANCE,
-        dispatchablePlanContext);
-  }
-
-  private static QueryPlan finalizeQueryPlan(DispatchablePlanContext dispatchablePlanContext) {
-    return new QueryPlan(dispatchablePlanContext.getResultFields(),
-        dispatchablePlanContext.getDispatchablePlanStageRootMap(),
-        dispatchablePlanContext.getDispatchablePlanMetadataMap());
-  }
-
   private static DispatchablePlanMetadata getOrCreateDispatchablePlanMetadata(PlanNode node,
       DispatchablePlanContext context) {
     return context.getDispatchablePlanMetadataMap().computeIfAbsent(node.getPlanFragmentId(),
         (id) -> new DispatchablePlanMetadata());
   }
 
-  private static void computeWorkerAssignment(PlanNode node, DispatchablePlanContext context) {
-    int planFragmentId = node.getPlanFragmentId();
-    context.getWorkerManager()
-        .assignWorkerToStage(planFragmentId, context.getDispatchablePlanMetadataMap().get(planFragmentId),
-            context.getRequestId(), context.getPlannerContext().getOptions(), context.getTableNames());
-  }
-
   @Override
   public Void visitAggregate(AggregateNode node, DispatchablePlanContext context) {
     node.getInputs().get(0).visit(this, context);
@@ -150,9 +103,7 @@ public class DispatchablePlanVisitor implements PlanNodeVisitor<Void, Dispatchab
   public Void visitMailboxSend(MailboxSendNode node, DispatchablePlanContext context) {
     node.getInputs().get(0).visit(this, context);
     getOrCreateDispatchablePlanMetadata(node, context);
-
     context.getDispatchablePlanStageRootMap().put(node.getPlanFragmentId(), node);
-    computeWorkerAssignment(node, context);
     return null;
   }
 
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/PinotDispatchPlanner.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/PinotDispatchPlanner.java
new file mode 100644
index 0000000000..bded25f200
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/PinotDispatchPlanner.java
@@ -0,0 +1,105 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.planner.physical;
+
+import org.apache.pinot.common.config.provider.TableCache;
+import org.apache.pinot.query.context.PlannerContext;
+import org.apache.pinot.query.planner.DispatchableSubPlan;
+import org.apache.pinot.query.planner.PlanFragment;
+import org.apache.pinot.query.planner.SubPlan;
+import org.apache.pinot.query.planner.physical.colocated.GreedyShuffleRewriteVisitor;
+import org.apache.pinot.query.planner.plannode.PlanNode;
+import org.apache.pinot.query.routing.WorkerManager;
+
+
+public class PinotDispatchPlanner {
+
+  private final WorkerManager _workerManager;
+  private final long _requestId;
+  private final PlannerContext _plannerContext;
+
+  private final TableCache _tableCache;
+
+  public PinotDispatchPlanner(PlannerContext plannerContext, WorkerManager workerManager, long requestId,
+      TableCache tableCache) {
+    _plannerContext = plannerContext;
+    _workerManager = workerManager;
+    _requestId = requestId;
+    _tableCache = tableCache;
+  }
+
+  /**
+   * Entry point for attaching dispatch metadata to a {@link SubPlan}.
+   * @param subPlan the entrypoint of the sub plan.
+   */
+  public DispatchableSubPlan createDispatchableSubPlan(SubPlan subPlan) {
+    // perform physical plan conversion and assign workers to each stage.
+    DispatchablePlanContext dispatchablePlanContext = new DispatchablePlanContext(_workerManager, _requestId,
+        _plannerContext, subPlan.getSubPlanMetadata().getFields(), subPlan.getSubPlanMetadata().getTableNames());
+    PlanNode subPlanRoot = subPlan.getSubPlanRoot().getFragmentRoot();
+    // 1. start by visiting the sub plan fragment root.
+    subPlanRoot.visit(DispatchablePlanVisitor.INSTANCE, dispatchablePlanContext);
+    // 2. add a special stage for the global mailbox receive, this runs on the dispatcher.
+    dispatchablePlanContext.getDispatchablePlanStageRootMap().put(0, subPlanRoot);
+    // 3. add worker assignment after the dispatchable plan context is fulfilled after the visit.
+    computeWorkerAssignment(subPlan.getSubPlanRoot(), dispatchablePlanContext);
+    // 4. compute the mailbox assignment for each stage.
+    // TODO: refactor this to be a pluggable interface.
+    computeMailboxAssignment(dispatchablePlanContext);
+    // 5. Run physical optimizations
+    runPhysicalOptimizers(subPlanRoot, dispatchablePlanContext, _tableCache);
+    // 6. convert it into query plan.
+    // TODO: refactor this to be a pluggable interface.
+    return finalizeDispatchableSubPlan(subPlan.getSubPlanRoot(), dispatchablePlanContext);
+  }
+
+  private void computeMailboxAssignment(DispatchablePlanContext dispatchablePlanContext) {
+    dispatchablePlanContext.getDispatchablePlanStageRootMap().get(0).visit(MailboxAssignmentVisitor.INSTANCE,
+        dispatchablePlanContext);
+  }
+
+  // TODO: Switch to Worker SPI to avoid multiple-places where workers are assigned.
+  private void runPhysicalOptimizers(PlanNode subPlanRoot, DispatchablePlanContext dispatchablePlanContext,
+      TableCache tableCache) {
+    if (dispatchablePlanContext.getPlannerContext().getOptions().getOrDefault("useColocatedJoin", "false")
+        .equals("true")) {
+      GreedyShuffleRewriteVisitor.optimizeShuffles(subPlanRoot,
+          dispatchablePlanContext.getDispatchablePlanMetadataMap(), tableCache);
+    }
+  }
+
+  private static DispatchableSubPlan finalizeDispatchableSubPlan(PlanFragment subPlanRoot,
+      DispatchablePlanContext dispatchablePlanContext) {
+    return new DispatchableSubPlan(dispatchablePlanContext.getResultFields(),
+        dispatchablePlanContext.constructDispatchablePlanFragmentList(subPlanRoot),
+        dispatchablePlanContext.getTableNames());
+  }
+
+  private static void computeWorkerAssignment(PlanFragment planFragment, DispatchablePlanContext context) {
+    computeWorkerAssignment(planFragment.getFragmentRoot(), context);
+    planFragment.getChildren().forEach(child -> computeWorkerAssignment(child, context));
+  }
+
+  private static void computeWorkerAssignment(PlanNode node, DispatchablePlanContext context) {
+    int planFragmentId = node.getPlanFragmentId();
+    context.getWorkerManager()
+        .assignWorkerToStage(planFragmentId, context.getDispatchablePlanMetadataMap().get(planFragmentId),
+            context.getRequestId(), context.getPlannerContext().getOptions(), context.getTableNames());
+  }
+}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewriteVisitor.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewriteVisitor.java
index 99e2f2ab68..0057dc35ab 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewriteVisitor.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewriteVisitor.java
@@ -29,7 +29,6 @@ import java.util.Set;
 import java.util.stream.Collectors;
 import org.apache.calcite.rel.RelDistribution;
 import org.apache.pinot.common.config.provider.TableCache;
-import org.apache.pinot.query.planner.QueryPlan;
 import org.apache.pinot.query.planner.logical.RexExpression;
 import org.apache.pinot.query.planner.partitioning.FieldSelectionKeySelector;
 import org.apache.pinot.query.planner.partitioning.KeySelector;
@@ -74,9 +73,8 @@ public class GreedyShuffleRewriteVisitor implements PlanNodeVisitor<Set<Colocati
   private final Map<Integer, DispatchablePlanMetadata> _dispatchablePlanMetadataMap;
   private boolean _canSkipShuffleForJoin;
 
-  public static void optimizeShuffles(QueryPlan queryPlan, TableCache tableCache) {
-    PlanNode rootPlanNode = queryPlan.getQueryStageMap().get(0);
-    Map<Integer, DispatchablePlanMetadata> dispatchablePlanMetadataMap = queryPlan.getDispatchablePlanMetadataMap();
+  public static void optimizeShuffles(PlanNode rootPlanNode,
+      Map<Integer, DispatchablePlanMetadata> dispatchablePlanMetadataMap, TableCache tableCache) {
     GreedyShuffleRewriteContext context = GreedyShuffleRewritePreComputeVisitor.preComputeContext(rootPlanNode);
     // This assumes that if planFragmentId(S1) > planFragmentId(S2), then S1 is not an ancestor of S2.
     // TODO: If this assumption is wrong, we can compute the reverse topological ordering explicitly.
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/ExchangeNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/ExchangeNode.java
index 04adca3eee..46baec9be5 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/ExchangeNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/ExchangeNode.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.query.planner.plannode;
 
 import java.util.List;
+import java.util.Set;
 import org.apache.calcite.rel.RelDistribution;
 import org.apache.calcite.rel.RelFieldCollation;
 import org.apache.pinot.common.utils.DataSchema;
@@ -46,11 +47,17 @@ public class ExchangeNode extends AbstractPlanNode {
   @ProtoProperties
   private List<RelFieldCollation> _collations;
 
+  /**
+   * The set of tables that are scanned in this planFragment.
+   */
+  @ProtoProperties
+  private Set<String> _tableNames;
+
   public ExchangeNode(int planFragmentId) {
     super(planFragmentId);
   }
 
-  public ExchangeNode(int currentStageId, DataSchema dataSchema, RelDistribution distribution,
+  public ExchangeNode(int currentStageId, DataSchema dataSchema, Set<String> tableNames, RelDistribution distribution,
       List<RelFieldCollation> collations, boolean isSortOnSender,
       boolean isSortOnReceiver) {
     super(currentStageId, dataSchema);
@@ -59,6 +66,7 @@ public class ExchangeNode extends AbstractPlanNode {
     _isSortOnSender = isSortOnSender;
     _isSortOnReceiver = isSortOnReceiver;
     _collations = collations;
+    _tableNames = tableNames;
   }
 
   @Override
@@ -90,4 +98,8 @@ public class ExchangeNode extends AbstractPlanNode {
   public List<RelFieldCollation> getCollations() {
     return _collations;
   }
+
+  public Set<String> getTableNames() {
+    return _tableNames;
+  }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/PlanNodeVisitor.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/PlanNodeVisitor.java
index d917cab021..d13eee07b0 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/PlanNodeVisitor.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/PlanNodeVisitor.java
@@ -19,7 +19,6 @@
 package org.apache.pinot.query.planner.plannode;
 
 import org.apache.pinot.query.planner.ExplainPlanPlanVisitor;
-import org.apache.pinot.query.planner.QueryPlan;
 
 
 /**
@@ -28,7 +27,7 @@ import org.apache.pinot.query.planner.QueryPlan;
  * enforced traversal order, and should be implemented by subclasses.
  *
  * <p>It is recommended that implementors use private constructors and static methods to access main
- * functionality (see {@link ExplainPlanPlanVisitor#explain(QueryPlan)}
+ * functionality (see {@link ExplainPlanPlanVisitor#explain(org.apache.pinot.query.planner.DispatchableSubPlan)}
  * as an example of a usage of this pattern.
  *
  * @param <T> the return type for all visitsPlanNodeVisitor
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/PlanFragmentMetadata.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/StageMetadata.java
similarity index 84%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/routing/PlanFragmentMetadata.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/routing/StageMetadata.java
index f1da239417..16d0da897b 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/PlanFragmentMetadata.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/StageMetadata.java
@@ -25,13 +25,13 @@ import org.apache.pinot.core.routing.TimeBoundaryInfo;
 
 
 /**
- * {@code PlanFragmentMetadata} is used to send plan fragment-level info about how to execute a stage physically.
+ * {@code StageMetadata} is used to send plan fragment-level info about how to execute a stage physically.
  */
-public class PlanFragmentMetadata {
+public class StageMetadata {
   private final List<WorkerMetadata> _workerMetadataList;
   private final Map<String, String> _customProperties;
 
-  public PlanFragmentMetadata(List<WorkerMetadata> workerMetadataList, Map<String, String> customProperties) {
+  public StageMetadata(List<WorkerMetadata> workerMetadataList, Map<String, String> customProperties) {
     _workerMetadataList = workerMetadataList;
     _customProperties = customProperties;
   }
@@ -71,8 +71,8 @@ public class PlanFragmentMetadata {
       return this;
     }
 
-    public PlanFragmentMetadata build() {
-      return new PlanFragmentMetadata(_workerMetadataList, _customProperties);
+    public StageMetadata build() {
+      return new StageMetadata(_workerMetadataList, _customProperties);
     }
 
     public void putAllCustomProperties(Map<String, String> customPropertyMap) {
@@ -80,11 +80,11 @@ public class PlanFragmentMetadata {
     }
   }
 
-  public static String getTableName(PlanFragmentMetadata metadata) {
+  public static String getTableName(StageMetadata metadata) {
     return metadata.getCustomProperties().get(Builder.TABLE_NAME_KEY);
   }
 
-  public static TimeBoundaryInfo getTimeBoundary(PlanFragmentMetadata metadata) {
+  public static TimeBoundaryInfo getTimeBoundary(StageMetadata metadata) {
     String timeColumn = metadata.getCustomProperties().get(Builder.TIME_BOUNDARY_COLUMN_KEY);
     String timeValue = metadata.getCustomProperties().get(Builder.TIME_BOUNDARY_VALUE_KEY);
     return timeColumn != null && timeValue != null ? new TimeBoundaryInfo(timeColumn, timeValue) : null;
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
index 41a88ef1b6..c3e45746bf 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
@@ -28,10 +28,10 @@ import java.util.concurrent.locks.Lock;
 import java.util.concurrent.locks.ReentrantLock;
 import java.util.stream.Collectors;
 import org.apache.calcite.rel.RelDistribution;
+import org.apache.pinot.query.planner.DispatchablePlanFragment;
+import org.apache.pinot.query.planner.DispatchableSubPlan;
 import org.apache.pinot.query.planner.ExplainPlanPlanVisitor;
 import org.apache.pinot.query.planner.PlannerUtils;
-import org.apache.pinot.query.planner.QueryPlan;
-import org.apache.pinot.query.planner.physical.DispatchablePlanMetadata;
 import org.apache.pinot.query.planner.plannode.AbstractPlanNode;
 import org.apache.pinot.query.planner.plannode.AggregateNode;
 import org.apache.pinot.query.planner.plannode.FilterNode;
@@ -61,8 +61,8 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
   public void testQueryPlanWithoutException(String query)
       throws Exception {
     try {
-      QueryPlan queryPlan = _queryEnvironment.planQuery(query);
-      Assert.assertNotNull(queryPlan);
+      DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(query);
+      Assert.assertNotNull(dispatchableSubPlan);
     } catch (RuntimeException e) {
       Assert.fail("failed to plan query: " + query, e);
     }
@@ -78,11 +78,12 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
     }
   }
 
-  private static void assertGroupBySingletonAfterJoin(QueryPlan queryPlan, boolean shouldRewrite)
+  private static void assertGroupBySingletonAfterJoin(DispatchableSubPlan dispatchableSubPlan, boolean shouldRewrite)
       throws Exception {
-    for (Map.Entry<Integer, DispatchablePlanMetadata> e : queryPlan.getDispatchablePlanMetadataMap().entrySet()) {
-      if (e.getValue().getScannedTables().size() == 0 && !PlannerUtils.isRootPlanFragment(e.getKey())) {
-        PlanNode node = queryPlan.getQueryStageMap().get(e.getKey());
+
+    for (int stageId = 0; stageId < dispatchableSubPlan.getQueryStageList().size(); stageId++) {
+      if (dispatchableSubPlan.getTableNames().size() == 0 && !PlannerUtils.isRootPlanFragment(stageId)) {
+        PlanNode node = dispatchableSubPlan.getQueryStageList().get(stageId).getPlanFragment().getFragmentRoot();
         while (node != null) {
           if (node instanceof JoinNode) {
             // JOIN is exchanged with hash distribution (data shuffle)
@@ -112,32 +113,27 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
   public void testQueryAndAssertStageContentForJoin()
       throws Exception {
     String query = "SELECT * FROM a JOIN b ON a.col1 = b.col2";
-    QueryPlan queryPlan = _queryEnvironment.planQuery(query);
-    Assert.assertEquals(queryPlan.getQueryStageMap().size(), 4);
-    Assert.assertEquals(queryPlan.getDispatchablePlanMetadataMap().size(), 4);
-    for (Map.Entry<Integer, DispatchablePlanMetadata> e : queryPlan.getDispatchablePlanMetadataMap().entrySet()) {
-      List<String> tables = e.getValue().getScannedTables();
-      if (tables.size() == 1) {
+    DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(query);
+    Assert.assertEquals(dispatchableSubPlan.getQueryStageList().size(), 4);
+
+    for (int stageId = 0; stageId < dispatchableSubPlan.getQueryStageList().size(); stageId++) {
+      DispatchablePlanFragment dispatchablePlanFragment = dispatchableSubPlan.getQueryStageList().get(stageId);
+      String tableName = dispatchablePlanFragment.getTableName();
+      if (tableName != null) {
         // table scan stages; for tableA it should have 2 hosts, for tableB it should have only 1
-        Assert.assertEquals(
-            e.getValue().getServerInstanceToWorkerIdMap().entrySet().stream()
-                .map(ExplainPlanPlanVisitor::stringifyQueryServerInstanceToWorkerIdsEntry)
-                .collect(Collectors.toSet()),
-            tables.get(0).equals("a") ? ImmutableList.of("localhost@{1,1}|[1]", "localhost@{2,2}|[0]")
+        Assert.assertEquals(dispatchablePlanFragment.getServerInstanceToWorkerIdMap().entrySet().stream()
+                .map(ExplainPlanPlanVisitor::stringifyQueryServerInstanceToWorkerIdsEntry).collect(Collectors.toSet()),
+            tableName.equals("a") ? ImmutableList.of("localhost@{1,1}|[1]", "localhost@{2,2}|[0]")
                 : ImmutableList.of("localhost@{1,1}|[0]"));
-      } else if (!PlannerUtils.isRootPlanFragment(e.getKey())) {
+      } else if (!PlannerUtils.isRootPlanFragment(stageId)) {
         // join stage should have both servers used.
-        Assert.assertEquals(
-            e.getValue().getServerInstanceToWorkerIdMap().entrySet().stream()
-                .map(ExplainPlanPlanVisitor::stringifyQueryServerInstanceToWorkerIdsEntry)
-                .collect(Collectors.toSet()),
+        Assert.assertEquals(dispatchablePlanFragment.getServerInstanceToWorkerIdMap().entrySet().stream()
+                .map(ExplainPlanPlanVisitor::stringifyQueryServerInstanceToWorkerIdsEntry).collect(Collectors.toSet()),
             ImmutableSet.of("localhost@{1,1}|[1]", "localhost@{2,2}|[0]"));
       } else {
         // reduce stage should have the reducer instance.
-        Assert.assertEquals(
-            e.getValue().getServerInstanceToWorkerIdMap().entrySet().stream()
-                .map(ExplainPlanPlanVisitor::stringifyQueryServerInstanceToWorkerIdsEntry)
-                .collect(Collectors.toSet()),
+        Assert.assertEquals(dispatchablePlanFragment.getServerInstanceToWorkerIdMap().entrySet().stream()
+                .map(ExplainPlanPlanVisitor::stringifyQueryServerInstanceToWorkerIdsEntry).collect(Collectors.toSet()),
             ImmutableSet.of("localhost@{3,3}|[0]"));
       }
     }
@@ -147,13 +143,13 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
   public void testQueryProjectFilterPushDownForJoin() {
     String query = "SELECT a.col1, a.ts, b.col2, b.col3 FROM a JOIN b ON a.col1 = b.col2 "
         + "WHERE a.col3 >= 0 AND a.col2 IN ('b') AND b.col3 < 0";
-    QueryPlan queryPlan = _queryEnvironment.planQuery(query);
-    List<PlanNode> intermediateStageRoots =
-        queryPlan.getDispatchablePlanMetadataMap().entrySet().stream()
-            .filter(e -> e.getValue().getScannedTables().size() == 0)
-            .map(e -> queryPlan.getQueryStageMap().get(e.getKey())).collect(Collectors.toList());
+    DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(query);
+    List<DispatchablePlanFragment> intermediateStages =
+        dispatchableSubPlan.getQueryStageList().stream().filter(q -> q.getTableName() == null)
+            .collect(Collectors.toList());
     // Assert that no project of filter node for any intermediate stage because all should've been pushed down.
-    for (PlanNode roots : intermediateStageRoots) {
+    for (DispatchablePlanFragment dispatchablePlanFragment : intermediateStages) {
+      PlanNode roots = dispatchablePlanFragment.getPlanFragment().getFragmentRoot();
       assertNodeTypeNotIn(roots, ImmutableList.of(ProjectNode.class, FilterNode.class));
     }
   }
@@ -161,26 +157,23 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
   @Test
   public void testQueryRoutingManagerCompilation() {
     String query = "SELECT * FROM d_OFFLINE";
-    QueryPlan queryPlan = _queryEnvironment.planQuery(query);
-    List<DispatchablePlanMetadata> tableScanMetadataList = queryPlan.getDispatchablePlanMetadataMap().values().stream()
-        .filter(planFragmentMetadata -> planFragmentMetadata.getScannedTables().size() != 0)
-        .collect(Collectors.toList());
+    DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(query);
+    List<DispatchablePlanFragment> tableScanMetadataList = dispatchableSubPlan.getQueryStageList().stream()
+        .filter(stageMetadata -> stageMetadata.getTableName() != null).collect(Collectors.toList());
     Assert.assertEquals(tableScanMetadataList.size(), 1);
     Assert.assertEquals(tableScanMetadataList.get(0).getServerInstanceToWorkerIdMap().size(), 2);
 
     query = "SELECT * FROM d_REALTIME";
-    queryPlan = _queryEnvironment.planQuery(query);
-    tableScanMetadataList = queryPlan.getDispatchablePlanMetadataMap().values().stream()
-        .filter(planFragmentMetadata -> planFragmentMetadata.getScannedTables().size() != 0)
-        .collect(Collectors.toList());
+    dispatchableSubPlan = _queryEnvironment.planQuery(query);
+    tableScanMetadataList = dispatchableSubPlan.getQueryStageList().stream()
+        .filter(stageMetadata -> stageMetadata.getTableName() != null).collect(Collectors.toList());
     Assert.assertEquals(tableScanMetadataList.size(), 1);
     Assert.assertEquals(tableScanMetadataList.get(0).getServerInstanceToWorkerIdMap().size(), 1);
 
     query = "SELECT * FROM d";
-    queryPlan = _queryEnvironment.planQuery(query);
-    tableScanMetadataList = queryPlan.getDispatchablePlanMetadataMap().values().stream()
-        .filter(planFragmentMetadata -> planFragmentMetadata.getScannedTables().size() != 0)
-        .collect(Collectors.toList());
+    dispatchableSubPlan = _queryEnvironment.planQuery(query);
+    tableScanMetadataList = dispatchableSubPlan.getQueryStageList().stream()
+        .filter(stageMetadata -> stageMetadata.getTableName() != null).collect(Collectors.toList());
     Assert.assertEquals(tableScanMetadataList.size(), 1);
     Assert.assertEquals(tableScanMetadataList.get(0).getServerInstanceToWorkerIdMap().size(), 2);
   }
@@ -189,26 +182,26 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
   @Test
   public void testPlanQueryMultiThread()
       throws Exception {
-    Map<String, ArrayList<QueryPlan>> queryPlans = new HashMap<>();
+    Map<String, ArrayList<DispatchableSubPlan>> queryPlans = new HashMap<>();
     Lock lock = new ReentrantLock();
     Runnable joinQuery = () -> {
       String query = "SELECT a.col1, a.ts, b.col2, b.col3 FROM a JOIN b ON a.col1 = b.col2";
-      QueryPlan queryPlan = _queryEnvironment.planQuery(query);
+      DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(query);
       lock.lock();
-      if (!queryPlans.containsKey(queryPlan)) {
+      if (!queryPlans.containsKey(dispatchableSubPlan)) {
         queryPlans.put(query, new ArrayList<>());
       }
-      queryPlans.get(query).add(queryPlan);
+      queryPlans.get(query).add(dispatchableSubPlan);
       lock.unlock();
     };
     Runnable selectQuery = () -> {
       String query = "SELECT * FROM a";
-      QueryPlan queryPlan = _queryEnvironment.planQuery(query);
+      DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(query);
       lock.lock();
-      if (!queryPlans.containsKey(queryPlan)) {
+      if (!queryPlans.containsKey(dispatchableSubPlan)) {
         queryPlans.put(query, new ArrayList<>());
       }
-      queryPlans.get(query).add(queryPlan);
+      queryPlans.get(query).add(dispatchableSubPlan);
       lock.unlock();
     };
     ArrayList<Thread> threads = new ArrayList<>();
@@ -228,8 +221,8 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
     for (Thread t : threads) {
       t.join();
     }
-    for (ArrayList<QueryPlan> plans : queryPlans.values()) {
-      for (QueryPlan plan : plans) {
+    for (ArrayList<DispatchableSubPlan> plans : queryPlans.values()) {
+      for (DispatchableSubPlan plan : plans) {
         Assert.assertTrue(plan.equals(plans.get(0)));
       }
     }
@@ -241,22 +234,20 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
     // Hinting the query to use final stage aggregation makes server directly return final result
     // This is useful when data is already partitioned by col1
     String query = "SELECT /*+ aggFinalStage */ col1, COUNT(*) FROM b GROUP BY col1";
-    QueryPlan queryPlan = _queryEnvironment.planQuery(query);
-    Assert.assertEquals(queryPlan.getQueryStageMap().size(), 2);
-    Assert.assertEquals(queryPlan.getDispatchablePlanMetadataMap().size(), 2);
-    for (Map.Entry<Integer, DispatchablePlanMetadata> e : queryPlan.getDispatchablePlanMetadataMap().entrySet()) {
-      List<String> tables = e.getValue().getScannedTables();
-      if (tables.size() != 0) {
+    DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(query);
+    Assert.assertEquals(dispatchableSubPlan.getQueryStageList().size(), 2);
+    for (int stageId = 0; stageId < dispatchableSubPlan.getQueryStageList().size(); stageId++) {
+      DispatchablePlanFragment dispatchablePlanFragment = dispatchableSubPlan.getQueryStageList().get(stageId);
+      String tableName = dispatchablePlanFragment.getTableName();
+      if (tableName != null) {
         // table scan stages; for tableB it should have only 1
-        Assert.assertEquals(e.getValue().getServerInstanceToWorkerIdMap().entrySet().stream()
-                .map(ExplainPlanPlanVisitor::stringifyQueryServerInstanceToWorkerIdsEntry)
-                .collect(Collectors.toSet()),
+        Assert.assertEquals(dispatchablePlanFragment.getServerInstanceToWorkerIdMap().entrySet().stream()
+                .map(ExplainPlanPlanVisitor::stringifyQueryServerInstanceToWorkerIdsEntry).collect(Collectors.toSet()),
             ImmutableList.of("localhost@{1,1}|[0]"));
-      } else if (!PlannerUtils.isRootPlanFragment(e.getKey())) {
+      } else if (!PlannerUtils.isRootPlanFragment(stageId)) {
         // join stage should have both servers used.
-        Assert.assertEquals(e.getValue().getServerInstanceToWorkerIdMap().entrySet().stream()
-                .map(ExplainPlanPlanVisitor::stringifyQueryServerInstanceToWorkerIdsEntry)
-                .collect(Collectors.toSet()),
+        Assert.assertEquals(dispatchablePlanFragment.getServerInstanceToWorkerIdMap().entrySet().stream()
+                .map(ExplainPlanPlanVisitor::stringifyQueryServerInstanceToWorkerIdsEntry).collect(Collectors.toSet()),
             ImmutableList.of("localhost@{1,1}|[1]", "localhost@{2,2}|[0]"));
       }
     }
@@ -273,8 +264,7 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
     }
   }
 
-  private static boolean isOneOf(List<Class<? extends AbstractPlanNode>> allowedNodeTypes,
-      PlanNode node) {
+  private static boolean isOneOf(List<Class<? extends AbstractPlanNode>> allowedNodeTypes, PlanNode node) {
     for (Class<? extends AbstractPlanNode> allowedNodeType : allowedNodeTypes) {
       if (node.getClass() == allowedNodeType) {
         return true;
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/RelToStageConverterTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverterTest.java
similarity index 72%
rename from pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/RelToStageConverterTest.java
rename to pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverterTest.java
index 32a27c979b..427829c71d 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/RelToStageConverterTest.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverterTest.java
@@ -29,104 +29,104 @@ import org.testng.Assert;
 import org.testng.annotations.Test;
 
 
-public class RelToStageConverterTest {
+public class RelToPlanNodeConverterTest {
 
   @Test
   public void testConvertToColumnDataTypeForObjectTypes() {
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ObjectSqlType(SqlTypeName.BOOLEAN, SqlIdentifier.STAR, true, null, null)),
         DataSchema.ColumnDataType.BOOLEAN);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ObjectSqlType(SqlTypeName.TINYINT, SqlIdentifier.STAR, true, null, null)),
         DataSchema.ColumnDataType.INT);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ObjectSqlType(SqlTypeName.SMALLINT, SqlIdentifier.STAR, true, null, null)),
         DataSchema.ColumnDataType.INT);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ObjectSqlType(SqlTypeName.INTEGER, SqlIdentifier.STAR, true, null, null)),
         DataSchema.ColumnDataType.INT);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ObjectSqlType(SqlTypeName.BIGINT, SqlIdentifier.STAR, true, null, null)),
         DataSchema.ColumnDataType.LONG);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ObjectSqlType(SqlTypeName.FLOAT, SqlIdentifier.STAR, true, null, null)),
         DataSchema.ColumnDataType.FLOAT);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ObjectSqlType(SqlTypeName.DOUBLE, SqlIdentifier.STAR, true, null, null)),
         DataSchema.ColumnDataType.DOUBLE);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ObjectSqlType(SqlTypeName.TIMESTAMP, SqlIdentifier.STAR, true, null, null)),
         DataSchema.ColumnDataType.TIMESTAMP);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ObjectSqlType(SqlTypeName.CHAR, SqlIdentifier.STAR, true, null, null)),
         DataSchema.ColumnDataType.STRING);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ObjectSqlType(SqlTypeName.VARCHAR, SqlIdentifier.STAR, true, null, null)),
         DataSchema.ColumnDataType.STRING);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ObjectSqlType(SqlTypeName.VARBINARY, SqlIdentifier.STAR, true, null, null)),
         DataSchema.ColumnDataType.BYTES);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ObjectSqlType(SqlTypeName.OTHER, SqlIdentifier.STAR, true, null, null)),
         DataSchema.ColumnDataType.OBJECT);
   }
 
   @Test
   public void testBigDecimal() {
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DECIMAL, 10)),
         DataSchema.ColumnDataType.INT);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DECIMAL, 38)),
         DataSchema.ColumnDataType.LONG);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DECIMAL, 39)),
         DataSchema.ColumnDataType.BIG_DECIMAL);
 
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DECIMAL, 14, 10)),
         DataSchema.ColumnDataType.FLOAT);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DECIMAL, 30, 10)),
         DataSchema.ColumnDataType.DOUBLE);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DECIMAL, 31, 10)),
         DataSchema.ColumnDataType.BIG_DECIMAL);
   }
 
   @Test
   public void testConvertToColumnDataTypeForArray() {
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ArraySqlType(new ObjectSqlType(SqlTypeName.BOOLEAN, SqlIdentifier.STAR, true, null, null), true)),
         DataSchema.ColumnDataType.BOOLEAN_ARRAY);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ArraySqlType(new ObjectSqlType(SqlTypeName.TINYINT, SqlIdentifier.STAR, true, null, null), true)),
         DataSchema.ColumnDataType.INT_ARRAY);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ArraySqlType(new ObjectSqlType(SqlTypeName.SMALLINT, SqlIdentifier.STAR, true, null, null), true)),
         DataSchema.ColumnDataType.INT_ARRAY);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ArraySqlType(new ObjectSqlType(SqlTypeName.INTEGER, SqlIdentifier.STAR, true, null, null), true)),
         DataSchema.ColumnDataType.INT_ARRAY);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ArraySqlType(new ObjectSqlType(SqlTypeName.BIGINT, SqlIdentifier.STAR, true, null, null), true)),
         DataSchema.ColumnDataType.LONG_ARRAY);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ArraySqlType(new ObjectSqlType(SqlTypeName.FLOAT, SqlIdentifier.STAR, true, null, null), true)),
         DataSchema.ColumnDataType.FLOAT_ARRAY);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ArraySqlType(new ObjectSqlType(SqlTypeName.DOUBLE, SqlIdentifier.STAR, true, null, null), true)),
         DataSchema.ColumnDataType.DOUBLE_ARRAY);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ArraySqlType(new ObjectSqlType(SqlTypeName.TIMESTAMP, SqlIdentifier.STAR, true, null, null), true)),
         DataSchema.ColumnDataType.TIMESTAMP_ARRAY);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ArraySqlType(new ObjectSqlType(SqlTypeName.CHAR, SqlIdentifier.STAR, true, null, null), true)),
         DataSchema.ColumnDataType.STRING_ARRAY);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ArraySqlType(new ObjectSqlType(SqlTypeName.VARCHAR, SqlIdentifier.STAR, true, null, null), true)),
         DataSchema.ColumnDataType.STRING_ARRAY);
-    Assert.assertEquals(RelToStageConverter.convertToColumnDataType(
+    Assert.assertEquals(RelToPlanNodeConverter.convertToColumnDataType(
             new ArraySqlType(new ObjectSqlType(SqlTypeName.VARBINARY, SqlIdentifier.STAR, true, null, null), true)),
         DataSchema.ColumnDataType.BYTES_ARRAY);
   }
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/plannode/SerDeUtilsTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/plannode/SerDeUtilsTest.java
index 111c83bd66..59132601b0 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/plannode/SerDeUtilsTest.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/plannode/SerDeUtilsTest.java
@@ -23,7 +23,8 @@ import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.proto.Plan;
 import org.apache.pinot.query.QueryEnvironmentTestBase;
-import org.apache.pinot.query.planner.QueryPlan;
+import org.apache.pinot.query.planner.DispatchablePlanFragment;
+import org.apache.pinot.query.planner.DispatchableSubPlan;
 import org.apache.pinot.query.planner.serde.ProtoProperties;
 import org.testng.Assert;
 import org.testng.annotations.Test;
@@ -34,14 +35,16 @@ public class SerDeUtilsTest extends QueryEnvironmentTestBase {
   @Test(dataProvider = "testQueryDataProvider")
   public void testQueryStagePlanSerDe(String query)
       throws Exception {
-    QueryPlan queryPlan = _queryEnvironment.planQuery(query);
-    for (PlanNode planNode : queryPlan.getQueryStageMap().values()) {
-      Plan.StageNode serializedStageNode = StageNodeSerDeUtils.serializeStageNode((AbstractPlanNode) planNode);
-      PlanNode deserializedPlanNode = StageNodeSerDeUtils.deserializeStageNode(serializedStageNode);
-      Assert.assertTrue(isObjectEqual(planNode, deserializedPlanNode));
-      Assert.assertEquals(deserializedPlanNode.getPlanFragmentId(), planNode.getPlanFragmentId());
-      Assert.assertEquals(deserializedPlanNode.getDataSchema(), planNode.getDataSchema());
-      Assert.assertEquals(deserializedPlanNode.getInputs().size(), planNode.getInputs().size());
+
+    DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(query);
+    for (DispatchablePlanFragment dispatchablePlanFragment : dispatchableSubPlan.getQueryStageList()) {
+      PlanNode stageNode = dispatchablePlanFragment.getPlanFragment().getFragmentRoot();
+      Plan.StageNode serializedStageNode = StageNodeSerDeUtils.serializeStageNode((AbstractPlanNode) stageNode);
+      PlanNode deserializedStageNode = StageNodeSerDeUtils.deserializeStageNode(serializedStageNode);
+      Assert.assertTrue(isObjectEqual(stageNode, deserializedStageNode));
+      Assert.assertEquals(deserializedStageNode.getPlanFragmentId(), stageNode.getPlanFragmentId());
+      Assert.assertEquals(deserializedStageNode.getDataSchema(), stageNode.getDataSchema());
+      Assert.assertEquals(deserializedStageNode.getInputs().size(), stageNode.getInputs().size());
     }
   }
 
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/queries/ResourceBasedQueryPlansTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/queries/ResourceBasedQueryPlansTest.java
index 50218e1d4c..dc3c17713e 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/queries/ResourceBasedQueryPlansTest.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/queries/ResourceBasedQueryPlansTest.java
@@ -33,7 +33,7 @@ import java.util.Map;
 import java.util.regex.Pattern;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.pinot.query.QueryEnvironmentTestBase;
-import org.apache.pinot.query.planner.QueryPlan;
+import org.apache.pinot.query.planner.DispatchableSubPlan;
 import org.testng.Assert;
 import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
@@ -51,9 +51,10 @@ public class ResourceBasedQueryPlansTest extends QueryEnvironmentTestBase {
       Assert.assertEquals(explainedPlan, output,
           String.format("Test case %s for query %s doesn't match expected output: %s", testCaseName, query, output));
       String queryWithoutExplainPlan = query.replace("EXPLAIN PLAN FOR ", "");
-      QueryPlan queryPlan = _queryEnvironment.planQuery(queryWithoutExplainPlan);
-      Assert.assertNotNull(queryPlan, String.format("Test case %s for query %s should not have a null QueryPlan",
-          testCaseName, queryWithoutExplainPlan));
+      DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(queryWithoutExplainPlan);
+      Assert.assertNotNull(dispatchableSubPlan,
+          String.format("Test case %s for query %s should not have a null QueryPlan",
+              testCaseName, queryWithoutExplainPlan));
     } catch (Exception e) {
       Assert.fail("Test case: " + testCaseName + " failed to explain query: " + query, e);
     }
@@ -109,27 +110,27 @@ public class ResourceBasedQueryPlansTest extends QueryEnvironmentTestBase {
   private static Object[][] testResourceQueryPlannerTestCaseProviderExceptions()
       throws Exception {
     Map<String, QueryPlanTestCase> testCaseMap = getTestCases();
-      List<Object[]> providerContent = new ArrayList<>();
-        for (Map.Entry<String, QueryPlanTestCase> testCaseEntry : testCaseMap.entrySet()) {
-        String testCaseName = testCaseEntry.getKey();
-        if (testCaseEntry.getValue()._ignored) {
+    List<Object[]> providerContent = new ArrayList<>();
+    for (Map.Entry<String, QueryPlanTestCase> testCaseEntry : testCaseMap.entrySet()) {
+      String testCaseName = testCaseEntry.getKey();
+      if (testCaseEntry.getValue()._ignored) {
+        continue;
+      }
+
+      List<QueryPlanTestCase.Query> queryCases = testCaseEntry.getValue()._queries;
+      for (QueryPlanTestCase.Query queryCase : queryCases) {
+        if (queryCase._ignored) {
           continue;
         }
 
-        List<QueryPlanTestCase.Query> queryCases = testCaseEntry.getValue()._queries;
-        for (QueryPlanTestCase.Query queryCase : queryCases) {
-          if (queryCase._ignored) {
-            continue;
-          }
-
-          if (queryCase._expectedException != null) {
-            String sql = queryCase._sql;
-            String exceptionString = queryCase._expectedException;
-            Object[] testEntry = new Object[]{testCaseName, sql, exceptionString};
-            providerContent.add(testEntry);
-          }
+        if (queryCase._expectedException != null) {
+          String sql = queryCase._sql;
+          String exceptionString = queryCase._expectedException;
+          Object[] testEntry = new Object[]{testCaseName, sql, exceptionString};
+          providerContent.add(testEntry);
         }
       }
+    }
     return providerContent.toArray(new Object[][]{});
   }
 
@@ -161,7 +162,8 @@ public class ResourceBasedQueryPlansTest extends QueryEnvironmentTestBase {
       // This test only supports local resource loading (e.g. must be a file), not support JAR test loading.
       if (testFileUrl != null && new File(testFileUrl.getFile()).exists()) {
         Map<String, QueryPlanTestCase> testCases = MAPPER.readValue(new File(testFileUrl.getFile()),
-            new TypeReference<Map<String, QueryPlanTestCase>>() { });
+            new TypeReference<Map<String, QueryPlanTestCase>>() {
+            });
         {
           HashSet<String> hashSet = new HashSet<>(testCaseMap.keySet());
           hashSet.retainAll(testCases.keySet());
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
index 8876a950e3..0ff209b904 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
@@ -42,7 +42,7 @@ import org.apache.pinot.core.query.scheduler.resources.ResourceManager;
 import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.planner.plannode.MailboxSendNode;
 import org.apache.pinot.query.planner.plannode.PlanNode;
-import org.apache.pinot.query.routing.PlanFragmentMetadata;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.executor.LeafSchedulerService;
@@ -208,9 +208,9 @@ public class QueryRunner {
   private static List<ServerPlanRequestContext> constructServerQueryRequests(DistributedStagePlan distributedStagePlan,
       Map<String, String> requestMetadataMap, ZkHelixPropertyStore<ZNRecord> helixPropertyStore,
       MailboxService mailboxService, long deadlineMs) {
-    PlanFragmentMetadata planFragmentMetadata = distributedStagePlan.getStageMetadata();
+    StageMetadata stageMetadata = distributedStagePlan.getStageMetadata();
     WorkerMetadata workerMetadata = distributedStagePlan.getCurrentWorkerMetadata();
-    String rawTableName = PlanFragmentMetadata.getTableName(planFragmentMetadata);
+    String rawTableName = StageMetadata.getTableName(stageMetadata);
     Map<String, List<String>> tableToSegmentListMap = WorkerMetadata.getTableSegmentsMap(workerMetadata);
     List<ServerPlanRequestContext> requests = new ArrayList<>();
     for (Map.Entry<String, List<String>> tableEntry : tableToSegmentListMap.entrySet()) {
@@ -224,7 +224,7 @@ public class QueryRunner {
         Schema schema = ZKMetadataProvider.getTableSchema(helixPropertyStore,
             TableNameBuilder.forType(TableType.OFFLINE).tableNameWithType(rawTableName));
         requests.add(ServerRequestPlanVisitor.build(mailboxService, distributedStagePlan, requestMetadataMap,
-            tableConfig, schema, PlanFragmentMetadata.getTimeBoundary(planFragmentMetadata), TableType.OFFLINE,
+            tableConfig, schema, StageMetadata.getTimeBoundary(stageMetadata), TableType.OFFLINE,
             tableEntry.getValue(), deadlineMs));
       } else if (TableType.REALTIME.name().equals(tableType)) {
         TableConfig tableConfig = ZKMetadataProvider.getTableConfig(helixPropertyStore,
@@ -232,7 +232,7 @@ public class QueryRunner {
         Schema schema = ZKMetadataProvider.getTableSchema(helixPropertyStore,
             TableNameBuilder.forType(TableType.REALTIME).tableNameWithType(rawTableName));
         requests.add(ServerRequestPlanVisitor.build(mailboxService, distributedStagePlan, requestMetadataMap,
-            tableConfig, schema, PlanFragmentMetadata.getTimeBoundary(planFragmentMetadata), TableType.REALTIME,
+            tableConfig, schema, StageMetadata.getTimeBoundary(stageMetadata), TableType.REALTIME,
             tableEntry.getValue(), deadlineMs));
       } else {
         throw new IllegalArgumentException("Unsupported table type key: " + tableType);
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/OperatorUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/OperatorUtils.java
index 62cb7c3d45..b771165cad 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/OperatorUtils.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/OperatorUtils.java
@@ -25,7 +25,7 @@ import java.util.Map;
 import org.apache.commons.lang.StringUtils;
 import org.apache.pinot.common.datablock.MetadataBlock;
 import org.apache.pinot.common.datatable.DataTable;
-import org.apache.pinot.query.routing.PlanFragmentMetadata;
+import org.apache.pinot.query.planner.DispatchablePlanFragment;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.runtime.operator.OperatorStats;
 import org.apache.pinot.spi.utils.JsonUtils;
@@ -67,10 +67,10 @@ public class OperatorUtils {
     return functionName;
   }
 
-  public static void recordTableName(OperatorStats operatorStats, PlanFragmentMetadata planFragmentMetadata) {
-    if (PlanFragmentMetadata.getTableName(planFragmentMetadata) != null) {
-      operatorStats.recordSingleStat(DataTable.MetadataKey.TABLE.getName(),
-          PlanFragmentMetadata.getTableName(planFragmentMetadata));
+  public static void recordTableName(OperatorStats operatorStats, DispatchablePlanFragment dispatchablePlanFragment) {
+    String tableName = dispatchablePlanFragment.getTableName();
+    if (tableName != null) {
+      operatorStats.recordSingleStat(DataTable.MetadataKey.TABLE.getName(), tableName);
     }
   }
 
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java
index da987d7b9b..7f4e3015f7 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java
@@ -19,7 +19,7 @@
 package org.apache.pinot.query.runtime.plan;
 
 import org.apache.pinot.query.planner.plannode.PlanNode;
-import org.apache.pinot.query.routing.PlanFragmentMetadata;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.routing.WorkerMetadata;
 
@@ -34,18 +34,18 @@ public class DistributedStagePlan {
   private int _stageId;
   private VirtualServerAddress _server;
   private PlanNode _stageRoot;
-  private PlanFragmentMetadata _planFragmentMetadata;
+  private StageMetadata _stageMetadata;
 
   public DistributedStagePlan(int stageId) {
     _stageId = stageId;
   }
 
   public DistributedStagePlan(int stageId, VirtualServerAddress server, PlanNode stageRoot,
-      PlanFragmentMetadata planFragmentMetadata) {
+      StageMetadata stageMetadata) {
     _stageId = stageId;
     _server = server;
     _stageRoot = stageRoot;
-    _planFragmentMetadata = planFragmentMetadata;
+    _stageMetadata = stageMetadata;
   }
 
   public int getStageId() {
@@ -60,8 +60,8 @@ public class DistributedStagePlan {
     return _stageRoot;
   }
 
-  public PlanFragmentMetadata getStageMetadata() {
-    return _planFragmentMetadata;
+  public StageMetadata getStageMetadata() {
+    return _stageMetadata;
   }
 
   public void setServer(VirtualServerAddress serverAddress) {
@@ -72,11 +72,11 @@ public class DistributedStagePlan {
     _stageRoot = stageRoot;
   }
 
-  public void setStageMetadata(PlanFragmentMetadata planFragmentMetadata) {
-    _planFragmentMetadata = planFragmentMetadata;
+  public void setStageMetadata(StageMetadata stageMetadata) {
+    _stageMetadata = stageMetadata;
   }
 
   public WorkerMetadata getCurrentWorkerMetadata() {
-    return _planFragmentMetadata.getWorkerMetadataList().get(_server.workerId());
+    return _stageMetadata.getWorkerMetadataList().get(_server.workerId());
   }
 }
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/OpChainExecutionContext.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/OpChainExecutionContext.java
index 0fd764c00e..11e8107996 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/OpChainExecutionContext.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/OpChainExecutionContext.java
@@ -19,7 +19,7 @@
 package org.apache.pinot.query.runtime.plan;
 
 import org.apache.pinot.query.mailbox.MailboxService;
-import org.apache.pinot.query.routing.PlanFragmentMetadata;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.runtime.operator.OpChainId;
 import org.apache.pinot.query.runtime.operator.OpChainStats;
@@ -37,13 +37,13 @@ public class OpChainExecutionContext {
   private final VirtualServerAddress _server;
   private final long _timeoutMs;
   private final long _deadlineMs;
-  private final PlanFragmentMetadata _planFragmentMetadata;
+  private final StageMetadata _stageMetadata;
   private final OpChainId _id;
   private final OpChainStats _stats;
   private final boolean _traceEnabled;
 
   public OpChainExecutionContext(MailboxService mailboxService, long requestId, int stageId,
-      VirtualServerAddress server, long timeoutMs, long deadlineMs, PlanFragmentMetadata planFragmentMetadata,
+      VirtualServerAddress server, long timeoutMs, long deadlineMs, StageMetadata stageMetadata,
       boolean traceEnabled) {
     _mailboxService = mailboxService;
     _requestId = requestId;
@@ -51,7 +51,7 @@ public class OpChainExecutionContext {
     _server = server;
     _timeoutMs = timeoutMs;
     _deadlineMs = deadlineMs;
-    _planFragmentMetadata = planFragmentMetadata;
+    _stageMetadata = stageMetadata;
     _id = new OpChainId(requestId, server.workerId(), stageId);
     _stats = new OpChainStats(_id.toString());
     _traceEnabled = traceEnabled;
@@ -87,8 +87,8 @@ public class OpChainExecutionContext {
     return _deadlineMs;
   }
 
-  public PlanFragmentMetadata getStageMetadata() {
-    return _planFragmentMetadata;
+  public StageMetadata getStageMetadata() {
+    return _stageMetadata;
   }
 
   public OpChainId getId() {
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PlanRequestContext.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PlanRequestContext.java
index 4383fa65c1..d3d890d9d5 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PlanRequestContext.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PlanRequestContext.java
@@ -21,7 +21,7 @@ package org.apache.pinot.query.runtime.plan;
 import java.util.ArrayList;
 import java.util.List;
 import org.apache.pinot.query.mailbox.MailboxService;
-import org.apache.pinot.query.routing.PlanFragmentMetadata;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 
 
@@ -33,20 +33,20 @@ public class PlanRequestContext {
   private final long _timeoutMs;
   private final long _deadlineMs;
   protected final VirtualServerAddress _server;
-  protected final PlanFragmentMetadata _planFragmentMetadata;
+  protected final StageMetadata _stageMetadata;
   protected final List<String> _receivingMailboxIds = new ArrayList<>();
   private final OpChainExecutionContext _opChainExecutionContext;
   private final boolean _traceEnabled;
 
   public PlanRequestContext(MailboxService mailboxService, long requestId, int stageId, long timeoutMs, long deadlineMs,
-      VirtualServerAddress server, PlanFragmentMetadata planFragmentMetadata, boolean traceEnabled) {
+      VirtualServerAddress server, StageMetadata stageMetadata, boolean traceEnabled) {
     _mailboxService = mailboxService;
     _requestId = requestId;
     _stageId = stageId;
     _timeoutMs = timeoutMs;
     _deadlineMs = deadlineMs;
     _server = server;
-    _planFragmentMetadata = planFragmentMetadata;
+    _stageMetadata = stageMetadata;
     _traceEnabled = traceEnabled;
     _opChainExecutionContext = new OpChainExecutionContext(this);
   }
@@ -71,8 +71,8 @@ public class PlanRequestContext {
     return _server;
   }
 
-  public PlanFragmentMetadata getStageMetadata() {
-    return _planFragmentMetadata;
+  public StageMetadata getStageMetadata() {
+    return _stageMetadata;
   }
 
   public MailboxService getMailboxService() {
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
index 1844df632c..62b2af1ab4 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
@@ -28,7 +28,7 @@ import org.apache.pinot.common.proto.Worker;
 import org.apache.pinot.query.planner.plannode.AbstractPlanNode;
 import org.apache.pinot.query.planner.plannode.StageNodeSerDeUtils;
 import org.apache.pinot.query.routing.MailboxMetadata;
-import org.apache.pinot.query.routing.PlanFragmentMetadata;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.plan.DistributedStagePlan;
@@ -79,8 +79,8 @@ public class QueryPlanSerDeUtils {
     return String.format("%s@%s:%s", serverAddress.workerId(), serverAddress.hostname(), serverAddress.port());
   }
 
-  private static PlanFragmentMetadata fromProtoStageMetadata(Worker.StageMetadata protoStageMetadata) {
-    PlanFragmentMetadata.Builder builder = new PlanFragmentMetadata.Builder();
+  private static StageMetadata fromProtoStageMetadata(Worker.StageMetadata protoStageMetadata) {
+    StageMetadata.Builder builder = new StageMetadata.Builder();
     List<WorkerMetadata> workerMetadataList = new ArrayList<>();
     for (Worker.WorkerMetadata protoWorkerMetadata : protoStageMetadata.getWorkerMetadataList()) {
       workerMetadataList.add(fromProtoWorkerMetadata(protoWorkerMetadata));
@@ -119,12 +119,12 @@ public class QueryPlanSerDeUtils {
     return mailboxMetadata;
   }
 
-  private static Worker.StageMetadata toProtoStageMetadata(PlanFragmentMetadata planFragmentMetadata) {
+  private static Worker.StageMetadata toProtoStageMetadata(StageMetadata stageMetadata) {
     Worker.StageMetadata.Builder builder = Worker.StageMetadata.newBuilder();
-    for (WorkerMetadata workerMetadata : planFragmentMetadata.getWorkerMetadataList()) {
+    for (WorkerMetadata workerMetadata : stageMetadata.getWorkerMetadataList()) {
       builder.addWorkerMetadata(toProtoWorkerMetadata(workerMetadata));
     }
-    builder.putAllCustomProperty(planFragmentMetadata.getCustomProperties());
+    builder.putAllCustomProperty(stageMetadata.getCustomProperties());
     return builder.build();
   }
 
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestContext.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestContext.java
index b1ba1624b5..1c0f7168ff 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestContext.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestContext.java
@@ -22,7 +22,7 @@ import org.apache.pinot.common.request.InstanceRequest;
 import org.apache.pinot.common.request.PinotQuery;
 import org.apache.pinot.core.routing.TimeBoundaryInfo;
 import org.apache.pinot.query.mailbox.MailboxService;
-import org.apache.pinot.query.routing.PlanFragmentMetadata;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.runtime.plan.PlanRequestContext;
 import org.apache.pinot.spi.config.table.TableType;
@@ -40,9 +40,9 @@ public class ServerPlanRequestContext extends PlanRequestContext {
   protected InstanceRequest _instanceRequest;
 
   public ServerPlanRequestContext(MailboxService mailboxService, long requestId, int stageId, long timeoutMs,
-      long deadlineMs, VirtualServerAddress server, PlanFragmentMetadata planFragmentMetadata, PinotQuery pinotQuery,
+      long deadlineMs, VirtualServerAddress server, StageMetadata stageMetadata, PinotQuery pinotQuery,
       TableType tableType, TimeBoundaryInfo timeBoundaryInfo, boolean traceEnabled) {
-    super(mailboxService, requestId, stageId, timeoutMs, deadlineMs, server, planFragmentMetadata, traceEnabled);
+    super(mailboxService, requestId, stageId, timeoutMs, deadlineMs, server, stageMetadata, traceEnabled);
     _pinotQuery = pinotQuery;
     _tableType = tableType;
     _timeBoundaryInfo = timeBoundaryInfo;
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
index c4c48abf59..aa700a2e51 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
@@ -45,11 +45,9 @@ import org.apache.pinot.core.common.ObjectSerDeUtils;
 import org.apache.pinot.core.query.reduce.ExecutionStatsAggregator;
 import org.apache.pinot.core.util.trace.TracedThreadFactory;
 import org.apache.pinot.query.mailbox.MailboxService;
+import org.apache.pinot.query.planner.DispatchableSubPlan;
 import org.apache.pinot.query.planner.ExplainPlanPlanVisitor;
-import org.apache.pinot.query.planner.QueryPlan;
-import org.apache.pinot.query.planner.physical.DispatchablePlanMetadata;
 import org.apache.pinot.query.planner.plannode.MailboxReceiveNode;
-import org.apache.pinot.query.routing.PlanFragmentMetadata;
 import org.apache.pinot.query.routing.QueryServerInstance;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
@@ -84,29 +82,33 @@ public class QueryDispatcher {
         new TracedThreadFactory(Thread.NORM_PRIORITY, false, PINOT_BROKER_QUERY_DISPATCHER_FORMAT));
   }
 
-  public ResultTable submitAndReduce(long requestId, QueryPlan queryPlan, MailboxService mailboxService, long timeoutMs,
+  public ResultTable submitAndReduce(long requestId, DispatchableSubPlan dispatchableSubPlan,
+      MailboxService mailboxService, long timeoutMs,
       Map<String, String> queryOptions, Map<Integer, ExecutionStatsAggregator> executionStatsAggregator,
       boolean traceEnabled)
       throws Exception {
     try {
       // submit all the distributed stages.
-      int reduceStageId = submit(requestId, queryPlan, timeoutMs, queryOptions);
+      int reduceStageId = submit(requestId, dispatchableSubPlan, timeoutMs, queryOptions);
       // run reduce stage and return result.
-      return runReducer(requestId, queryPlan, reduceStageId, timeoutMs, mailboxService, executionStatsAggregator,
+      return runReducer(requestId, dispatchableSubPlan, reduceStageId, timeoutMs, mailboxService,
+          executionStatsAggregator,
           traceEnabled);
     } catch (Exception e) {
-      cancel(requestId, queryPlan);
-      throw new RuntimeException("Error executing query: " + ExplainPlanPlanVisitor.explain(queryPlan), e);
+      cancel(requestId, dispatchableSubPlan);
+      throw new RuntimeException("Error executing query: " + ExplainPlanPlanVisitor.explain(dispatchableSubPlan), e);
     }
   }
 
-  private void cancel(long requestId, QueryPlan queryPlan) {
+  private void cancel(long requestId, DispatchableSubPlan dispatchableSubPlan) {
     Set<DispatchClient> dispatchClientSet = new HashSet<>();
-    for (Map.Entry<Integer, DispatchablePlanMetadata> stage : queryPlan.getDispatchablePlanMetadataMap().entrySet()) {
-      int stageId = stage.getKey();
+
+    for (int stageId = 0; stageId < dispatchableSubPlan.getQueryStageList().size(); stageId++) {
       // stage rooting at a mailbox receive node means reduce stage.
-      if (!(queryPlan.getQueryStageMap().get(stageId) instanceof MailboxReceiveNode)) {
-        Set<QueryServerInstance> serverInstances = stage.getValue().getServerInstanceToWorkerIdMap().keySet();
+      if (!(dispatchableSubPlan.getQueryStageList().get(stageId).getPlanFragment()
+          .getFragmentRoot() instanceof MailboxReceiveNode)) {
+        Set<QueryServerInstance> serverInstances =
+            dispatchableSubPlan.getQueryStageList().get(stageId).getServerInstanceToWorkerIdMap().keySet();
         for (QueryServerInstance serverInstance : serverInstances) {
           String host = serverInstance.getHostname();
           int servicePort = serverInstance.getQueryServicePort();
@@ -120,20 +122,22 @@ public class QueryDispatcher {
   }
 
   @VisibleForTesting
-  int submit(long requestId, QueryPlan queryPlan, long timeoutMs, Map<String, String> queryOptions)
+  int submit(long requestId, DispatchableSubPlan dispatchableSubPlan, long timeoutMs,
+      Map<String, String> queryOptions)
       throws Exception {
     int reduceStageId = -1;
     Deadline deadline = Deadline.after(timeoutMs, TimeUnit.MILLISECONDS);
     BlockingQueue<AsyncQueryDispatchResponse> dispatchCallbacks = new LinkedBlockingQueue<>();
     int dispatchCalls = 0;
-    for (Map.Entry<Integer, DispatchablePlanMetadata> stage : queryPlan.getDispatchablePlanMetadataMap().entrySet()) {
-      int stageId = stage.getKey();
+
+    for (int stageId = 0; stageId < dispatchableSubPlan.getQueryStageList().size(); stageId++) {
       // stage rooting at a mailbox receive node means reduce stage.
-      if (queryPlan.getQueryStageMap().get(stageId) instanceof MailboxReceiveNode) {
+      if (dispatchableSubPlan.getQueryStageList().get(stageId).getPlanFragment()
+          .getFragmentRoot() instanceof MailboxReceiveNode) {
         reduceStageId = stageId;
       } else {
         for (Map.Entry<QueryServerInstance, List<Integer>> queryServerEntry
-            : stage.getValue().getServerInstanceToWorkerIdMap().entrySet()) {
+            : dispatchableSubPlan.getQueryStageList().get(stageId).getServerInstanceToWorkerIdMap().entrySet()) {
           QueryServerInstance queryServerInstance = queryServerEntry.getKey();
           for (int workerId : queryServerEntry.getValue()) {
             String host = queryServerInstance.getHostname();
@@ -142,15 +146,14 @@ public class QueryDispatcher {
             VirtualServerAddress virtualServerAddress = new VirtualServerAddress(host, mailboxPort, workerId);
             DispatchClient client = getOrCreateDispatchClient(host, servicePort);
             dispatchCalls++;
-            _executorService.submit(() -> {
-              client.submit(Worker.QueryRequest.newBuilder().setStagePlan(
-                      QueryPlanSerDeUtils.serialize(constructDistributedStagePlan(queryPlan, stageId,
-                          virtualServerAddress)))
-                  .putMetadata(QueryConfig.KEY_OF_BROKER_REQUEST_ID, String.valueOf(requestId))
-                  .putMetadata(QueryConfig.KEY_OF_BROKER_REQUEST_TIMEOUT_MS, String.valueOf(timeoutMs))
-                  .putAllMetadata(queryOptions).build(), stageId, queryServerInstance, deadline,
-                  dispatchCallbacks::offer);
-            });
+            int finalStageId = stageId;
+            _executorService.submit(() -> client.submit(Worker.QueryRequest.newBuilder().setStagePlan(
+                        QueryPlanSerDeUtils.serialize(
+                            constructDistributedStagePlan(dispatchableSubPlan, finalStageId, virtualServerAddress)))
+                    .putMetadata(QueryConfig.KEY_OF_BROKER_REQUEST_ID, String.valueOf(requestId))
+                    .putMetadata(QueryConfig.KEY_OF_BROKER_REQUEST_TIMEOUT_MS, String.valueOf(timeoutMs))
+                    .putAllMetadata(queryOptions).build(), finalStageId, queryServerInstance, deadline,
+                dispatchCallbacks::offer));
           }
         }
       }
@@ -184,29 +187,37 @@ public class QueryDispatcher {
   }
 
   @VisibleForTesting
-  public static ResultTable runReducer(long requestId, QueryPlan queryPlan, int reduceStageId, long timeoutMs,
+  public static ResultTable runReducer(long requestId, DispatchableSubPlan dispatchableSubPlan, int reduceStageId,
+      long timeoutMs,
       MailboxService mailboxService, Map<Integer, ExecutionStatsAggregator> statsAggregatorMap, boolean traceEnabled) {
-    MailboxReceiveNode reduceNode = (MailboxReceiveNode) queryPlan.getQueryStageMap().get(reduceStageId);
+    MailboxReceiveNode reduceNode =
+        (MailboxReceiveNode) dispatchableSubPlan.getQueryStageList().get(reduceStageId).getPlanFragment()
+            .getFragmentRoot();
     VirtualServerAddress server = new VirtualServerAddress(mailboxService.getHostname(), mailboxService.getPort(), 0);
     OpChainExecutionContext context =
         new OpChainExecutionContext(mailboxService, requestId, reduceStageId, server, timeoutMs,
-            System.currentTimeMillis() + timeoutMs, queryPlan.getStageMetadata(reduceStageId), traceEnabled);
+            System.currentTimeMillis() + timeoutMs,
+            dispatchableSubPlan.getQueryStageList().get(reduceStageId).toStageMetadata(),
+            traceEnabled);
     MailboxReceiveOperator mailboxReceiveOperator = createReduceStageOperator(context, reduceNode.getSenderStageId());
     List<DataBlock> resultDataBlocks =
-        reduceMailboxReceive(mailboxReceiveOperator, timeoutMs, statsAggregatorMap, queryPlan, context.getStats());
-    return toResultTable(resultDataBlocks, queryPlan.getQueryResultFields(),
-        queryPlan.getQueryStageMap().get(0).getDataSchema());
+        reduceMailboxReceive(mailboxReceiveOperator, timeoutMs, statsAggregatorMap, dispatchableSubPlan,
+            context.getStats());
+    return toResultTable(resultDataBlocks, dispatchableSubPlan.getQueryResultFields(),
+        dispatchableSubPlan.getQueryStageList().get(0).getPlanFragment().getFragmentRoot().getDataSchema());
   }
 
   @VisibleForTesting
-  public static DistributedStagePlan constructDistributedStagePlan(QueryPlan queryPlan, int stageId,
-      VirtualServerAddress serverAddress) {
-    return new DistributedStagePlan(stageId, serverAddress, queryPlan.getQueryStageMap().get(stageId),
-        queryPlan.getStageMetadata(stageId));
+  public static DistributedStagePlan constructDistributedStagePlan(DispatchableSubPlan dispatchableSubPlan,
+      int stageId, VirtualServerAddress serverAddress) {
+    return new DistributedStagePlan(stageId, serverAddress,
+        dispatchableSubPlan.getQueryStageList().get(stageId).getPlanFragment().getFragmentRoot(),
+        dispatchableSubPlan.getQueryStageList().get(stageId).toStageMetadata());
   }
 
   private static List<DataBlock> reduceMailboxReceive(MailboxReceiveOperator mailboxReceiveOperator, long timeoutMs,
-      @Nullable Map<Integer, ExecutionStatsAggregator> executionStatsAggregatorMap, QueryPlan queryPlan,
+      @Nullable Map<Integer, ExecutionStatsAggregator> executionStatsAggregatorMap,
+      DispatchableSubPlan dispatchableSubPlan,
       OpChainStats stats) {
     List<DataBlock> resultDataBlocks = new ArrayList<>();
     TransferableBlock transferableBlock;
@@ -231,9 +242,9 @@ public class QueryDispatcher {
             ExecutionStatsAggregator stageStatsAggregator = executionStatsAggregatorMap.get(operatorStats.getStageId());
             rootStatsAggregator.aggregate(null, entry.getValue().getExecutionStats(), new HashMap<>());
             if (stageStatsAggregator != null) {
-              if (queryPlan != null) {
-                PlanFragmentMetadata planFragmentMetadata = queryPlan.getStageMetadata(operatorStats.getStageId());
-                OperatorUtils.recordTableName(operatorStats, planFragmentMetadata);
+              if (dispatchableSubPlan != null) {
+                OperatorUtils.recordTableName(operatorStats,
+                    dispatchableSubPlan.getQueryStageList().get(operatorStats.getStageId()));
               }
               stageStatsAggregator.aggregate(null, entry.getValue().getExecutionStats(), new HashMap<>());
             }
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
index 37465dc211..ca35b62a9c 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
@@ -30,7 +30,7 @@ import org.apache.pinot.core.common.datatable.DataTableBuilderFactory;
 import org.apache.pinot.query.QueryEnvironmentTestBase;
 import org.apache.pinot.query.QueryServerEnclosure;
 import org.apache.pinot.query.mailbox.MailboxService;
-import org.apache.pinot.query.planner.QueryPlan;
+import org.apache.pinot.query.planner.DispatchableSubPlan;
 import org.apache.pinot.query.planner.plannode.MailboxReceiveNode;
 import org.apache.pinot.query.routing.QueryServerInstance;
 import org.apache.pinot.query.service.QueryConfig;
@@ -64,6 +64,7 @@ public class QueryRunnerTest extends QueryRunnerTestBase {
       new Object[]{"charlie", "bar", 1},
   };
   public static final Schema.SchemaBuilder SCHEMA_BUILDER;
+
   static {
     SCHEMA_BUILDER = new Schema.SchemaBuilder()
         .addSingleValueDimension("col1", FieldSpec.DataType.STRING, "")
@@ -153,11 +154,10 @@ public class QueryRunnerTest extends QueryRunnerTestBase {
     _mailboxService.shutdown();
   }
 
-
   /**
    * Test compares with expected row count only.
    */
-   @Test(dataProvider = "testDataWithSqlToFinalRowCount")
+  @Test(dataProvider = "testDataWithSqlToFinalRowCount")
   public void testSqlWithFinalRowCountChecker(String sql, int expectedRows)
       throws Exception {
     List<Object[]> resultRows = queryRunner(sql, null);
@@ -185,23 +185,24 @@ public class QueryRunnerTest extends QueryRunnerTestBase {
   @Test(dataProvider = "testDataWithSqlExecutionExceptions")
   public void testSqlWithExceptionMsgChecker(String sql, String exceptionMsg) {
     long requestId = RANDOM_REQUEST_ID_GEN.nextLong();
-    QueryPlan queryPlan = _queryEnvironment.planQuery(sql);
+    DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(sql);
     Map<String, String> requestMetadataMap =
         ImmutableMap.of(QueryConfig.KEY_OF_BROKER_REQUEST_ID, String.valueOf(requestId),
             QueryConfig.KEY_OF_BROKER_REQUEST_TIMEOUT_MS,
             String.valueOf(CommonConstants.Broker.DEFAULT_BROKER_TIMEOUT_MS));
     int reducerStageId = -1;
-    for (int stageId : queryPlan.getDispatchablePlanMetadataMap().keySet()) {
-      if (queryPlan.getQueryStageMap().get(stageId) instanceof MailboxReceiveNode) {
+    for (int stageId = 0; stageId < dispatchableSubPlan.getQueryStageList().size(); stageId++) {
+      if (dispatchableSubPlan.getQueryStageList().get(stageId).getPlanFragment()
+          .getFragmentRoot() instanceof MailboxReceiveNode) {
         reducerStageId = stageId;
       } else {
-        processDistributedStagePlans(queryPlan, stageId, requestMetadataMap);
+        processDistributedStagePlans(dispatchableSubPlan, stageId, requestMetadataMap);
       }
     }
     Preconditions.checkState(reducerStageId != -1);
 
     try {
-      QueryDispatcher.runReducer(requestId, queryPlan, reducerStageId,
+      QueryDispatcher.runReducer(requestId, dispatchableSubPlan, reducerStageId,
           Long.parseLong(requestMetadataMap.get(QueryConfig.KEY_OF_BROKER_REQUEST_TIMEOUT_MS)), _mailboxService, null,
           false);
     } catch (RuntimeException rte) {
@@ -213,7 +214,7 @@ public class QueryRunnerTest extends QueryRunnerTestBase {
 
   @DataProvider(name = "testDataWithSqlToFinalRowCount")
   private Object[][] provideTestSqlAndRowCount() {
-    return new Object[][] {
+    return new Object[][]{
         // using join clause
         new Object[]{"SELECT * FROM a JOIN b USING (col1)", 15},
 
@@ -225,8 +226,10 @@ public class QueryRunnerTest extends QueryRunnerTestBase {
         new Object[]{"SELECT dateTrunc('DAY', ts) FROM a LIMIT 10", 10},
         new Object[]{"SELECT dateTrunc('DAY', CAST(col3 AS BIGINT)) FROM a LIMIT 10", 10},
         //   - on intermediate stage
-        new Object[]{"SELECT dateTrunc('DAY', round(a.ts, b.ts)) FROM a JOIN b "
-            + "ON a.col1 = b.col1 AND a.col2 = b.col2", 15},
+        new Object[]{
+            "SELECT dateTrunc('DAY', round(a.ts, b.ts)) FROM a JOIN b "
+                + "ON a.col1 = b.col1 AND a.col2 = b.col2", 15
+        },
         new Object[]{"SELECT dateTrunc('DAY', CAST(MAX(a.col3) AS BIGINT)) FROM a", 1},
 
         // ScalarFunction
@@ -250,14 +253,18 @@ public class QueryRunnerTest extends QueryRunnerTestBase {
 
   @DataProvider(name = "testDataWithSqlExecutionExceptions")
   private Object[][] provideTestSqlWithExecutionException() {
-    return new Object[][] {
+    return new Object[][]{
         // Timeout exception should occur with this option:
-        new Object[]{"SET timeoutMs = 1; SELECT * FROM a JOIN b ON a.col1 = b.col1 JOIN c ON a.col1 = c.col1",
-            "timeout"},
+        new Object[]{
+            "SET timeoutMs = 1; SELECT * FROM a JOIN b ON a.col1 = b.col1 JOIN c ON a.col1 = c.col1",
+            "timeout"
+        },
 
         // Function with incorrect argument signature should throw runtime exception when casting string to numeric
-        new Object[]{"SELECT least(a.col2, b.col3) FROM a JOIN b ON a.col1 = b.col1",
-            "For input string:"},
+        new Object[]{
+            "SELECT least(a.col2, b.col3) FROM a JOIN b ON a.col1 = b.col1",
+            "For input string:"
+        },
 
         // Scalar function that doesn't have a valid use should throw an exception on the leaf stage
         //   - predicate only functions:
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTestBase.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTestBase.java
index 55e43dae11..32164207d0 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTestBase.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTestBase.java
@@ -47,7 +47,7 @@ import org.apache.pinot.query.QueryEnvironment;
 import org.apache.pinot.query.QueryServerEnclosure;
 import org.apache.pinot.query.QueryTestSet;
 import org.apache.pinot.query.mailbox.MailboxService;
-import org.apache.pinot.query.planner.QueryPlan;
+import org.apache.pinot.query.planner.DispatchableSubPlan;
 import org.apache.pinot.query.planner.plannode.MailboxReceiveNode;
 import org.apache.pinot.query.routing.QueryServerInstance;
 import org.apache.pinot.query.routing.VirtualServerAddress;
@@ -94,7 +94,9 @@ public abstract class QueryRunnerTestBase extends QueryTestSet {
   protected List<Object[]> queryRunner(String sql, Map<Integer, ExecutionStatsAggregator> executionStatsAggregatorMap) {
     long requestId = RANDOM_REQUEST_ID_GEN.nextLong();
     SqlNodeAndOptions sqlNodeAndOptions = CalciteSqlParser.compileToSqlNodeAndOptions(sql);
-    QueryPlan queryPlan = _queryEnvironment.planQuery(sql, sqlNodeAndOptions, requestId).getQueryPlan();
+    QueryEnvironment.QueryPlannerResult queryPlannerResult =
+        _queryEnvironment.planQuery(sql, sqlNodeAndOptions, requestId);
+    DispatchableSubPlan dispatchableSubPlan = queryPlannerResult.getQueryPlan();
     Map<String, String> requestMetadataMap = new HashMap<>();
     requestMetadataMap.put(QueryConfig.KEY_OF_BROKER_REQUEST_ID, String.valueOf(requestId));
     requestMetadataMap.put(QueryConfig.KEY_OF_BROKER_REQUEST_TIMEOUT_MS,
@@ -107,32 +109,33 @@ public abstract class QueryRunnerTestBase extends QueryTestSet {
     }
 
     int reducerStageId = -1;
-    for (int stageId : queryPlan.getDispatchablePlanMetadataMap().keySet()) {
-      if (queryPlan.getQueryStageMap().get(stageId) instanceof MailboxReceiveNode) {
+    for (int stageId = 0; stageId < dispatchableSubPlan.getQueryStageList().size(); stageId++) {
+      if (dispatchableSubPlan.getQueryStageList().get(stageId).getPlanFragment()
+          .getFragmentRoot() instanceof MailboxReceiveNode) {
         reducerStageId = stageId;
       } else {
-        processDistributedStagePlans(queryPlan, stageId, requestMetadataMap);
+        processDistributedStagePlans(dispatchableSubPlan, stageId, requestMetadataMap);
       }
       if (executionStatsAggregatorMap != null) {
         executionStatsAggregatorMap.put(stageId, new ExecutionStatsAggregator(true));
       }
     }
     Preconditions.checkState(reducerStageId != -1);
-    ResultTable resultTable = QueryDispatcher.runReducer(requestId, queryPlan, reducerStageId,
+    ResultTable resultTable = QueryDispatcher.runReducer(requestId, dispatchableSubPlan, reducerStageId,
         Long.parseLong(requestMetadataMap.get(QueryConfig.KEY_OF_BROKER_REQUEST_TIMEOUT_MS)), _mailboxService,
         executionStatsAggregatorMap, true);
     return resultTable.getRows();
   }
 
-  protected void processDistributedStagePlans(QueryPlan queryPlan, int stageId,
+  protected void processDistributedStagePlans(DispatchableSubPlan dispatchableSubPlan, int stageId,
       Map<String, String> requestMetadataMap) {
     Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap =
-        queryPlan.getDispatchablePlanMetadataMap().get(stageId).getServerInstanceToWorkerIdMap();
+        dispatchableSubPlan.getQueryStageList().get(stageId).getServerInstanceToWorkerIdMap();
     for (Map.Entry<QueryServerInstance, List<Integer>> entry : serverInstanceToWorkerIdMap.entrySet()) {
       QueryServerInstance server = entry.getKey();
       for (int workerId : entry.getValue()) {
         DistributedStagePlan distributedStagePlan = QueryDispatcher.constructDistributedStagePlan(
-            queryPlan, stageId, new VirtualServerAddress(server, workerId));
+            dispatchableSubPlan, stageId, new VirtualServerAddress(server, workerId));
         _servers.get(server).processQuery(distributedStagePlan, requestMetadataMap);
       }
     }
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperatorTest.java
index 1764165b1f..8c602da156 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperatorTest.java
@@ -31,7 +31,7 @@ import org.apache.pinot.query.mailbox.MailboxIdUtils;
 import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.mailbox.ReceivingMailbox;
 import org.apache.pinot.query.routing.MailboxMetadata;
-import org.apache.pinot.query.routing.PlanFragmentMetadata;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
@@ -64,8 +64,8 @@ public class MailboxReceiveOperatorTest {
   private ReceivingMailbox _mailbox1;
   @Mock
   private ReceivingMailbox _mailbox2;
-  private PlanFragmentMetadata _planFragmentMetadataBoth;
-  private PlanFragmentMetadata _planFragmentMetadata1;
+  private StageMetadata _stageMetadataBoth;
+  private StageMetadata _stageMetadata1;
 
   @BeforeMethod
   public void setUp() {
@@ -74,7 +74,7 @@ public class MailboxReceiveOperatorTest {
     when(_mailboxService.getPort()).thenReturn(123);
     VirtualServerAddress server1 = new VirtualServerAddress("localhost", 123, 0);
     VirtualServerAddress server2 = new VirtualServerAddress("localhost", 123, 1);
-    _planFragmentMetadataBoth = new PlanFragmentMetadata.Builder()
+    _stageMetadataBoth = new StageMetadata.Builder()
         .setWorkerMetadataList(Stream.of(server1, server2).map(
                 s -> new WorkerMetadata.Builder()
                     .setVirtualServerAddress(s)
@@ -94,7 +94,7 @@ public class MailboxReceiveOperatorTest {
             .collect(Collectors.toList()))
         .build();
     // sending stage is 0, receiving stage is 1
-    _planFragmentMetadata1 = new PlanFragmentMetadata.Builder()
+    _stageMetadata1 = new StageMetadata.Builder()
         .setWorkerMetadataList(Stream.of(server1).map(
             s -> new WorkerMetadata.Builder()
                 .setVirtualServerAddress(s)
@@ -120,13 +120,13 @@ public class MailboxReceiveOperatorTest {
   public void shouldThrowSingletonNoMatchMailboxServer() {
     VirtualServerAddress server1 = new VirtualServerAddress("localhost", 456, 0);
     VirtualServerAddress server2 = new VirtualServerAddress("localhost", 789, 1);
-    PlanFragmentMetadata planFragmentMetadata = new PlanFragmentMetadata.Builder()
+    StageMetadata stageMetadata = new StageMetadata.Builder()
         .setWorkerMetadataList(Stream.of(server1, server2).map(
             s -> new WorkerMetadata.Builder().setVirtualServerAddress(s).build()).collect(Collectors.toList()))
         .build();
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            planFragmentMetadata, false);
+            stageMetadata, false);
     //noinspection resource
     new MailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, 1);
   }
@@ -135,7 +135,7 @@ public class MailboxReceiveOperatorTest {
   public void shouldThrowReceiveSingletonFromMultiMatchMailboxServer() {
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            _planFragmentMetadataBoth, false);
+            _stageMetadataBoth, false);
     //noinspection resource
     new MailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, 1);
   }
@@ -157,7 +157,7 @@ public class MailboxReceiveOperatorTest {
     // Short timeoutMs should result in timeout
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, 10L, System.currentTimeMillis() + 10L,
-            _planFragmentMetadata1, false);
+            _stageMetadata1, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, 1)) {
       Thread.sleep(100L);
       TransferableBlock mailbox = receiveOp.nextBlock();
@@ -168,7 +168,7 @@ public class MailboxReceiveOperatorTest {
 
     // Longer timeout or default timeout (10s) doesn't result in timeout
     context = new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, 10_000L,
-        System.currentTimeMillis() + 10_000L, _planFragmentMetadata1, false);
+        System.currentTimeMillis() + 10_000L, _stageMetadata1, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, 1)) {
       Thread.sleep(100L);
       TransferableBlock mailbox = receiveOp.nextBlock();
@@ -182,7 +182,7 @@ public class MailboxReceiveOperatorTest {
 
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            _planFragmentMetadata1, false);
+            _stageMetadata1, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, 1)) {
       assertTrue(receiveOp.nextBlock().isNoOpBlock());
     }
@@ -195,7 +195,7 @@ public class MailboxReceiveOperatorTest {
 
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            _planFragmentMetadata1, false);
+            _stageMetadata1, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, 1)) {
       assertTrue(receiveOp.nextBlock().isEndOfStreamBlock());
     }
@@ -210,7 +210,7 @@ public class MailboxReceiveOperatorTest {
 
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            _planFragmentMetadata1, false);
+            _stageMetadata1, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, 1)) {
       List<Object[]> actualRows = receiveOp.nextBlock().getContainer();
       assertEquals(actualRows.size(), 1);
@@ -228,7 +228,7 @@ public class MailboxReceiveOperatorTest {
 
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            _planFragmentMetadata1, false);
+            _stageMetadata1, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, 1)) {
       TransferableBlock block = receiveOp.nextBlock();
       assertTrue(block.isErrorBlock());
@@ -247,7 +247,7 @@ public class MailboxReceiveOperatorTest {
 
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            _planFragmentMetadataBoth, false);
+            _stageMetadataBoth, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.HASH_DISTRIBUTED,
         1)) {
       List<Object[]> actualRows = receiveOp.nextBlock().getContainer();
@@ -271,7 +271,7 @@ public class MailboxReceiveOperatorTest {
 
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            _planFragmentMetadataBoth, false);
+            _stageMetadataBoth, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.HASH_DISTRIBUTED,
         1)) {
       // Receive first block from server1
@@ -297,7 +297,7 @@ public class MailboxReceiveOperatorTest {
 
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            _planFragmentMetadataBoth, false);
+            _stageMetadataBoth, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.HASH_DISTRIBUTED,
         1)) {
       TransferableBlock block = receiveOp.nextBlock();
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxSendOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxSendOperatorTest.java
index ec7c00d513..5080ad8454 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxSendOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxSendOperatorTest.java
@@ -23,7 +23,7 @@ import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.mailbox.MailboxService;
-import org.apache.pinot.query.routing.PlanFragmentMetadata;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
@@ -165,12 +165,12 @@ public class MailboxSendOperatorTest {
   }
 
   private MailboxSendOperator getMailboxSendOperator() {
-    PlanFragmentMetadata planFragmentMetadata = new PlanFragmentMetadata.Builder()
+    StageMetadata stageMetadata = new StageMetadata.Builder()
         .setWorkerMetadataList(Collections.singletonList(
             new WorkerMetadata.Builder().setVirtualServerAddress(_server).build())).build();
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, SENDER_STAGE_ID, _server, Long.MAX_VALUE, Long.MAX_VALUE,
-            planFragmentMetadata, false);
+            stageMetadata, false);
     return new MailboxSendOperator(context, _sourceOperator, _exchange, null, null, false);
   }
 }
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OpChainTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OpChainTest.java
index 585e884e65..e433ccf65c 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OpChainTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OpChainTest.java
@@ -42,7 +42,7 @@ import org.apache.pinot.query.mailbox.ReceivingMailbox;
 import org.apache.pinot.query.planner.logical.RexExpression;
 import org.apache.pinot.query.planner.physical.MailboxIdUtils;
 import org.apache.pinot.query.routing.MailboxMetadata;
-import org.apache.pinot.query.routing.PlanFragmentMetadata;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
@@ -84,13 +84,13 @@ public class OpChainTest {
   private BlockExchange _exchange;
 
   private VirtualServerAddress _serverAddress;
-  private PlanFragmentMetadata _receivingPlanFragmentMetadata;
+  private StageMetadata _receivingStageMetadata;
 
   @BeforeMethod
   public void setUp() {
     _mocks = MockitoAnnotations.openMocks(this);
     _serverAddress = new VirtualServerAddress("localhost", 123, 0);
-    _receivingPlanFragmentMetadata = new PlanFragmentMetadata.Builder()
+    _receivingStageMetadata = new StageMetadata.Builder()
         .setWorkerMetadataList(Stream.of(_serverAddress).map(
             s -> new WorkerMetadata.Builder()
                 .setVirtualServerAddress(s)
@@ -199,7 +199,7 @@ public class OpChainTest {
     int senderStageId = 1;
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService1, 1, senderStageId, _serverAddress, 1000,
-            System.currentTimeMillis() + 1000, _receivingPlanFragmentMetadata, true);
+            System.currentTimeMillis() + 1000, _receivingStageMetadata, true);
 
     Stack<MultiStageOperator> operators =
         getFullOpchain(receivedStageId, senderStageId, context, dummyOperatorWaitTime);
@@ -213,7 +213,7 @@ public class OpChainTest {
 
     OpChainExecutionContext secondStageContext =
         new OpChainExecutionContext(_mailboxService2, 1, senderStageId + 1, _serverAddress, 1000,
-            System.currentTimeMillis() + 1000, _receivingPlanFragmentMetadata, true);
+            System.currentTimeMillis() + 1000, _receivingStageMetadata, true);
 
     MailboxReceiveOperator secondStageReceiveOp =
         new MailboxReceiveOperator(secondStageContext, RelDistribution.Type.BROADCAST_DISTRIBUTED, senderStageId + 1);
@@ -239,7 +239,7 @@ public class OpChainTest {
     int senderStageId = 1;
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService1, 1, senderStageId, _serverAddress, 1000,
-            System.currentTimeMillis() + 1000, _receivingPlanFragmentMetadata, false);
+            System.currentTimeMillis() + 1000, _receivingStageMetadata, false);
 
     Stack<MultiStageOperator> operators =
         getFullOpchain(receivedStageId, senderStageId, context, dummyOperatorWaitTime);
@@ -251,7 +251,7 @@ public class OpChainTest {
 
     OpChainExecutionContext secondStageContext =
         new OpChainExecutionContext(_mailboxService2, 1, senderStageId + 1, _serverAddress, 1000,
-            System.currentTimeMillis() + 1000, _receivingPlanFragmentMetadata, false);
+            System.currentTimeMillis() + 1000, _receivingStageMetadata, false);
     MailboxReceiveOperator secondStageReceiveOp =
         new MailboxReceiveOperator(secondStageContext, RelDistribution.Type.BROADCAST_DISTRIBUTED, senderStageId);
 
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/SortedMailboxReceiveOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/SortedMailboxReceiveOperatorTest.java
index c17739e789..1a65a8457f 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/SortedMailboxReceiveOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/SortedMailboxReceiveOperatorTest.java
@@ -35,7 +35,7 @@ import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.mailbox.ReceivingMailbox;
 import org.apache.pinot.query.planner.logical.RexExpression;
 import org.apache.pinot.query.routing.MailboxMetadata;
-import org.apache.pinot.query.routing.PlanFragmentMetadata;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
@@ -73,8 +73,8 @@ public class SortedMailboxReceiveOperatorTest {
   @Mock
   private ReceivingMailbox _mailbox2;
 
-  private PlanFragmentMetadata _planFragmentMetadataBoth;
-  private PlanFragmentMetadata _planFragmentMetadata1;
+  private StageMetadata _stageMetadataBoth;
+  private StageMetadata _stageMetadata1;
 
   @BeforeMethod
   public void setUp() {
@@ -83,7 +83,7 @@ public class SortedMailboxReceiveOperatorTest {
     when(_mailboxService.getPort()).thenReturn(123);
     VirtualServerAddress server1 = new VirtualServerAddress("localhost", 123, 0);
     VirtualServerAddress server2 = new VirtualServerAddress("localhost", 123, 1);
-    _planFragmentMetadataBoth = new PlanFragmentMetadata.Builder()
+    _stageMetadataBoth = new StageMetadata.Builder()
         .setWorkerMetadataList(Stream.of(server1, server2).map(
                 s -> new WorkerMetadata.Builder()
                     .setVirtualServerAddress(s)
@@ -103,7 +103,7 @@ public class SortedMailboxReceiveOperatorTest {
             .collect(Collectors.toList()))
         .build();
     // sending stage is 0, receiving stage is 1
-    _planFragmentMetadata1 = new PlanFragmentMetadata.Builder()
+    _stageMetadata1 = new StageMetadata.Builder()
         .setWorkerMetadataList(Stream.of(server1).map(
             s -> new WorkerMetadata.Builder()
                 .setVirtualServerAddress(s)
@@ -129,13 +129,13 @@ public class SortedMailboxReceiveOperatorTest {
   public void shouldThrowSingletonNoMatchMailboxServer() {
     VirtualServerAddress server1 = new VirtualServerAddress("localhost", 456, 0);
     VirtualServerAddress server2 = new VirtualServerAddress("localhost", 789, 1);
-    PlanFragmentMetadata planFragmentMetadata = new PlanFragmentMetadata.Builder()
+    StageMetadata stageMetadata = new StageMetadata.Builder()
         .setWorkerMetadataList(Stream.of(server1, server2).map(
             s -> new WorkerMetadata.Builder().setVirtualServerAddress(s).build()).collect(Collectors.toList()))
         .build();
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            planFragmentMetadata, false);
+            stageMetadata, false);
     //noinspection resource
     new SortedMailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, DATA_SCHEMA, COLLATION_KEYS,
         COLLATION_DIRECTIONS, false, 1);
@@ -145,7 +145,7 @@ public class SortedMailboxReceiveOperatorTest {
   public void shouldThrowReceiveSingletonFromMultiMatchMailboxServer() {
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            _planFragmentMetadataBoth, false);
+            _stageMetadataBoth, false);
     //noinspection resource
     new SortedMailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, DATA_SCHEMA, COLLATION_KEYS,
         COLLATION_DIRECTIONS, false, 1);
@@ -166,7 +166,7 @@ public class SortedMailboxReceiveOperatorTest {
     when(_mailboxService.getReceivingMailbox(MAILBOX_ID_1)).thenReturn(_mailbox1);
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, 10L, System.currentTimeMillis() + 10L,
-            _planFragmentMetadata1, false);
+            _stageMetadata1, false);
     //noinspection resource
     new SortedMailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, DATA_SCHEMA, Collections.emptyList(),
         Collections.emptyList(), false, 1);
@@ -179,7 +179,7 @@ public class SortedMailboxReceiveOperatorTest {
     // Short timeoutMs should result in timeout
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, 10L, System.currentTimeMillis() + 10L,
-            _planFragmentMetadata1, false);
+            _stageMetadata1, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.SINGLETON, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       Thread.sleep(100L);
@@ -191,7 +191,7 @@ public class SortedMailboxReceiveOperatorTest {
 
     // Longer timeout or default timeout (10s) doesn't result in timeout
     context = new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, 10_000L,
-        System.currentTimeMillis() + 10_000L, _planFragmentMetadata1, false);
+        System.currentTimeMillis() + 10_000L, _stageMetadata1, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.SINGLETON, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       Thread.sleep(100L);
@@ -205,7 +205,7 @@ public class SortedMailboxReceiveOperatorTest {
     when(_mailboxService.getReceivingMailbox(MAILBOX_ID_1)).thenReturn(_mailbox1);
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            _planFragmentMetadata1, false);
+            _stageMetadata1, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.SINGLETON, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       assertTrue(receiveOp.nextBlock().isNoOpBlock());
@@ -218,7 +218,7 @@ public class SortedMailboxReceiveOperatorTest {
     when(_mailbox1.poll()).thenReturn(TransferableBlockUtils.getEndOfStreamTransferableBlock());
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            _planFragmentMetadata1, false);
+            _stageMetadata1, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.SINGLETON, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       assertTrue(receiveOp.nextBlock().isEndOfStreamBlock());
@@ -233,7 +233,7 @@ public class SortedMailboxReceiveOperatorTest {
         TransferableBlockUtils.getEndOfStreamTransferableBlock());
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            _planFragmentMetadata1, false);
+            _stageMetadata1, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.SINGLETON, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       List<Object[]> actualRows = receiveOp.nextBlock().getContainer();
@@ -251,7 +251,7 @@ public class SortedMailboxReceiveOperatorTest {
         TransferableBlockUtils.getErrorTransferableBlock(new RuntimeException(errorMessage)));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            _planFragmentMetadata1, false);
+            _stageMetadata1, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.SINGLETON, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       TransferableBlock block = receiveOp.nextBlock();
@@ -270,7 +270,7 @@ public class SortedMailboxReceiveOperatorTest {
         TransferableBlockUtils.getEndOfStreamTransferableBlock());
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            _planFragmentMetadataBoth, false);
+            _stageMetadataBoth, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.HASH_DISTRIBUTED, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       assertTrue(receiveOp.nextBlock().isNoOpBlock());
@@ -293,7 +293,7 @@ public class SortedMailboxReceiveOperatorTest {
         TransferableBlockUtils.getEndOfStreamTransferableBlock());
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            _planFragmentMetadataBoth, false);
+            _stageMetadataBoth, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.HASH_DISTRIBUTED, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       TransferableBlock block = receiveOp.nextBlock();
@@ -318,7 +318,7 @@ public class SortedMailboxReceiveOperatorTest {
         TransferableBlockUtils.getEndOfStreamTransferableBlock());
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            _planFragmentMetadataBoth, false);
+            _stageMetadataBoth, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.HASH_DISTRIBUTED, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       assertEquals(receiveOp.nextBlock().getContainer(), Arrays.asList(row5, row2, row4, row1, row3));
@@ -349,7 +349,7 @@ public class SortedMailboxReceiveOperatorTest {
 
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            _planFragmentMetadataBoth, false);
+            _stageMetadataBoth, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.HASH_DISTRIBUTED, dataSchema, collationKeys, collationDirection, false, 1)) {
       assertEquals(receiveOp.nextBlock().getContainer(), Arrays.asList(row1, row2, row3, row5, row4));
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryServerTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryServerTest.java
index ab7748850f..d9803acc17 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryServerTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryServerTest.java
@@ -36,10 +36,11 @@ import org.apache.pinot.core.routing.TimeBoundaryInfo;
 import org.apache.pinot.query.QueryEnvironment;
 import org.apache.pinot.query.QueryEnvironmentTestBase;
 import org.apache.pinot.query.QueryTestSet;
-import org.apache.pinot.query.planner.QueryPlan;
+import org.apache.pinot.query.planner.DispatchablePlanFragment;
+import org.apache.pinot.query.planner.DispatchableSubPlan;
 import org.apache.pinot.query.planner.plannode.PlanNode;
-import org.apache.pinot.query.routing.PlanFragmentMetadata;
 import org.apache.pinot.query.routing.QueryServerInstance;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.QueryRunner;
@@ -61,12 +62,15 @@ public class QueryServerTest extends QueryTestSet {
   private static final int QUERY_SERVER_COUNT = 2;
   private static final String KEY_OF_SERVER_INSTANCE_HOST = "pinot.query.runner.server.hostname";
   private static final String KEY_OF_SERVER_INSTANCE_PORT = "pinot.query.runner.server.port";
-  private static final ExecutorService LEAF_WORKER_EXECUTOR_SERVICE = Executors.newFixedThreadPool(
-      ResourceManager.DEFAULT_QUERY_WORKER_THREADS, new NamedThreadFactory("QueryDispatcherTest_LeafWorker"));
-  private static final ExecutorService INTERM_WORKER_EXECUTOR_SERVICE = Executors.newFixedThreadPool(
-      ResourceManager.DEFAULT_QUERY_WORKER_THREADS, new NamedThreadFactory("QueryDispatcherTest_IntermWorker"));
-  private static final ExecutorService RUNNER_EXECUTOR_SERVICE = Executors.newFixedThreadPool(
-      ResourceManager.DEFAULT_QUERY_RUNNER_THREADS, new NamedThreadFactory("QueryServerTest_Runner"));
+  private static final ExecutorService LEAF_WORKER_EXECUTOR_SERVICE =
+      Executors.newFixedThreadPool(ResourceManager.DEFAULT_QUERY_WORKER_THREADS,
+          new NamedThreadFactory("QueryDispatcherTest_LeafWorker"));
+  private static final ExecutorService INTERM_WORKER_EXECUTOR_SERVICE =
+      Executors.newFixedThreadPool(ResourceManager.DEFAULT_QUERY_WORKER_THREADS,
+          new NamedThreadFactory("QueryDispatcherTest_IntermWorker"));
+  private static final ExecutorService RUNNER_EXECUTOR_SERVICE =
+      Executors.newFixedThreadPool(ResourceManager.DEFAULT_QUERY_RUNNER_THREADS,
+          new NamedThreadFactory("QueryServerTest_Runner"));
 
   private final Map<Integer, QueryServer> _queryServerMap = new HashMap<>();
   private final Map<Integer, QueryRunner> _queryRunnerMap = new HashMap<>();
@@ -108,32 +112,36 @@ public class QueryServerTest extends QueryTestSet {
   @Test(dataProvider = "testSql")
   public void testWorkerAcceptsWorkerRequestCorrect(String sql)
       throws Exception {
-    QueryPlan queryPlan = _queryEnvironment.planQuery(sql);
+    DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(sql);
 
-    for (int stageId : queryPlan.getDispatchablePlanMetadataMap().keySet()) {
+    for (int stageId = 0; stageId < dispatchableSubPlan.getQueryStageList().size(); stageId++) {
       if (stageId > 0) { // we do not test reduce stage.
         // only get one worker request out.
-        Worker.QueryRequest queryRequest = getQueryRequest(queryPlan, stageId);
+        Worker.QueryRequest queryRequest = getQueryRequest(dispatchableSubPlan, stageId);
 
         // submit the request for testing.
         submitRequest(queryRequest);
 
-        PlanFragmentMetadata planFragmentMetadata = queryPlan.getStageMetadata(stageId);
+        DispatchablePlanFragment dispatchablePlanFragment = dispatchableSubPlan.getQueryStageList().get(stageId);
+
+        StageMetadata stageMetadata = dispatchablePlanFragment.toStageMetadata();
 
         // ensure mock query runner received correctly deserialized payload.
-        QueryRunner mockRunner = _queryRunnerMap.get(
-            Integer.parseInt(queryRequest.getMetadataOrThrow(KEY_OF_SERVER_INSTANCE_PORT)));
+        QueryRunner mockRunner =
+            _queryRunnerMap.get(Integer.parseInt(queryRequest.getMetadataOrThrow(KEY_OF_SERVER_INSTANCE_PORT)));
         String requestIdStr = queryRequest.getMetadataOrThrow(QueryConfig.KEY_OF_BROKER_REQUEST_ID);
 
         // since submitRequest is async, we need to wait for the mockRunner to receive the query payload.
+        int finalStageId = stageId;
         TestUtils.waitForCondition(aVoid -> {
           try {
             Mockito.verify(mockRunner).processQuery(Mockito.argThat(distributedStagePlan -> {
-              PlanNode planNode = queryPlan.getQueryStageMap().get(stageId);
-              return isStageNodesEqual(planNode, distributedStagePlan.getStageRoot())
-                  && isStageMetadataEqual(planFragmentMetadata, distributedStagePlan.getStageMetadata());
-            }), Mockito.argThat(requestMetadataMap ->
-                requestIdStr.equals(requestMetadataMap.get(QueryConfig.KEY_OF_BROKER_REQUEST_ID))));
+              PlanNode planNode =
+                  dispatchableSubPlan.getQueryStageList().get(finalStageId).getPlanFragment().getFragmentRoot();
+              return isStageNodesEqual(planNode, distributedStagePlan.getStageRoot()) && isStageMetadataEqual(
+                  stageMetadata, distributedStagePlan.getStageMetadata());
+            }), Mockito.argThat(requestMetadataMap -> requestIdStr.equals(
+                requestMetadataMap.get(QueryConfig.KEY_OF_BROKER_REQUEST_ID))));
             return true;
           } catch (Throwable t) {
             return false;
@@ -146,20 +154,21 @@ public class QueryServerTest extends QueryTestSet {
     }
   }
 
-  private boolean isStageMetadataEqual(PlanFragmentMetadata expected, PlanFragmentMetadata actual) {
-    if (!EqualityUtils.isEqual(PlanFragmentMetadata.getTableName(expected),
-        PlanFragmentMetadata.getTableName(actual))) {
+  private boolean isStageMetadataEqual(StageMetadata expected, StageMetadata actual) {
+    if (!EqualityUtils.isEqual(StageMetadata.getTableName(expected),
+        StageMetadata.getTableName(actual))) {
       return false;
     }
-    TimeBoundaryInfo expectedTimeBoundaryInfo = PlanFragmentMetadata.getTimeBoundary(expected);
-    TimeBoundaryInfo actualTimeBoundaryInfo = PlanFragmentMetadata.getTimeBoundary(actual);
+    TimeBoundaryInfo expectedTimeBoundaryInfo = StageMetadata.getTimeBoundary(expected);
+    TimeBoundaryInfo actualTimeBoundaryInfo = StageMetadata.getTimeBoundary(actual);
     if (expectedTimeBoundaryInfo == null && actualTimeBoundaryInfo != null
         || expectedTimeBoundaryInfo != null && actualTimeBoundaryInfo == null) {
       return false;
     }
-    if (expectedTimeBoundaryInfo != null && actualTimeBoundaryInfo != null
-        && (!EqualityUtils.isEqual(expectedTimeBoundaryInfo.getTimeColumn(), actualTimeBoundaryInfo.getTimeColumn())
-        || !EqualityUtils.isEqual(expectedTimeBoundaryInfo.getTimeValue(), actualTimeBoundaryInfo.getTimeValue()))) {
+    if (expectedTimeBoundaryInfo != null && actualTimeBoundaryInfo != null && (
+        !EqualityUtils.isEqual(expectedTimeBoundaryInfo.getTimeColumn(), actualTimeBoundaryInfo.getTimeColumn())
+            || !EqualityUtils.isEqual(expectedTimeBoundaryInfo.getTimeValue(),
+            actualTimeBoundaryInfo.getTimeValue()))) {
       return false;
     }
     List<WorkerMetadata> expectedWorkerMetadataList = expected.getWorkerMetadataList();
@@ -213,16 +222,16 @@ public class QueryServerTest extends QueryTestSet {
     channel.shutdown();
   }
 
-  private Worker.QueryRequest getQueryRequest(QueryPlan queryPlan, int stageId) {
+  private Worker.QueryRequest getQueryRequest(DispatchableSubPlan dispatchableSubPlan, int stageId) {
     Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap =
-        queryPlan.getDispatchablePlanMetadataMap().get(stageId).getServerInstanceToWorkerIdMap();
+        dispatchableSubPlan.getQueryStageList().get(stageId).getServerInstanceToWorkerIdMap();
     // this particular test set requires the request to have a single QueryServerInstance to dispatch to
     // as it is not testing the multi-tenancy dispatch (which is in the QueryDispatcherTest)
     QueryServerInstance serverInstance = serverInstanceToWorkerIdMap.keySet().iterator().next();
     int workerId = serverInstanceToWorkerIdMap.get(serverInstance).get(0);
 
     return Worker.QueryRequest.newBuilder().setStagePlan(QueryPlanSerDeUtils.serialize(
-            QueryDispatcher.constructDistributedStagePlan(queryPlan, stageId,
+            QueryDispatcher.constructDistributedStagePlan(dispatchableSubPlan, stageId,
                 new VirtualServerAddress(serverInstance, workerId))))
         // the default configurations that must exist.
         .putMetadata(QueryConfig.KEY_OF_BROKER_REQUEST_ID, String.valueOf(RANDOM_REQUEST_ID_GEN.nextLong()))
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java
index 7e98f1a7ae..69d2dda026 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java
@@ -33,8 +33,8 @@ import org.apache.pinot.core.query.scheduler.resources.ResourceManager;
 import org.apache.pinot.query.QueryEnvironment;
 import org.apache.pinot.query.QueryEnvironmentTestBase;
 import org.apache.pinot.query.QueryTestSet;
+import org.apache.pinot.query.planner.DispatchableSubPlan;
 import org.apache.pinot.query.planner.PlannerUtils;
-import org.apache.pinot.query.planner.QueryPlan;
 import org.apache.pinot.query.runtime.QueryRunner;
 import org.apache.pinot.query.service.QueryServer;
 import org.apache.pinot.query.testutils.QueryTestUtils;
@@ -70,7 +70,7 @@ public class QueryDispatcherTest extends QueryTestSet {
 
     for (int i = 0; i < QUERY_SERVER_COUNT; i++) {
       int availablePort = QueryTestUtils.getAvailablePort();
-      QueryRunner queryRunner = Mockito.mock(QueryRunner.class);;
+      QueryRunner queryRunner = Mockito.mock(QueryRunner.class);
       Mockito.when(queryRunner.getQueryWorkerLeafExecutorService()).thenReturn(LEAF_WORKER_EXECUTOR_SERVICE);
       Mockito.when(queryRunner.getQueryWorkerIntermExecutorService()).thenReturn(INTERM_WORKER_EXECUTOR_SERVICE);
       Mockito.when(queryRunner.getQueryRunnerExecutorService()).thenReturn(RUNNER_EXECUTOR_SERVICE);
@@ -99,9 +99,10 @@ public class QueryDispatcherTest extends QueryTestSet {
   @Test(dataProvider = "testSql")
   public void testQueryDispatcherCanSendCorrectPayload(String sql)
       throws Exception {
-    QueryPlan queryPlan = _queryEnvironment.planQuery(sql);
+    DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(sql);
     QueryDispatcher dispatcher = new QueryDispatcher();
-    int reducerStageId = dispatcher.submit(RANDOM_REQUEST_ID_GEN.nextLong(), queryPlan, 10_000L, new HashMap<>());
+    int reducerStageId =
+        dispatcher.submit(RANDOM_REQUEST_ID_GEN.nextLong(), dispatchableSubPlan, 10_000L, new HashMap<>());
     Assert.assertTrue(PlannerUtils.isRootPlanFragment(reducerStageId));
     dispatcher.shutdown();
   }
@@ -112,10 +113,10 @@ public class QueryDispatcherTest extends QueryTestSet {
     String sql = "SELECT * FROM a WHERE col1 = 'foo'";
     QueryServer failingQueryServer = _queryServerMap.values().iterator().next();
     Mockito.doThrow(new RuntimeException("foo")).when(failingQueryServer).submit(Mockito.any(), Mockito.any());
-    QueryPlan queryPlan = _queryEnvironment.planQuery(sql);
+    DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(sql);
     QueryDispatcher dispatcher = new QueryDispatcher();
     try {
-      dispatcher.submit(RANDOM_REQUEST_ID_GEN.nextLong(), queryPlan, 10_000L, new HashMap<>());
+      dispatcher.submit(RANDOM_REQUEST_ID_GEN.nextLong(), dispatchableSubPlan, 10_000L, new HashMap<>());
       Assert.fail("Method call above should have failed");
     } catch (Exception e) {
       Assert.assertTrue(e.getMessage().contains("Error dispatching query"));
@@ -137,11 +138,11 @@ public class QueryDispatcherTest extends QueryTestSet {
         return null;
       }
     }).when(failingQueryServer).submit(Mockito.any(), Mockito.any());
-    QueryPlan queryPlan = _queryEnvironment.planQuery(sql);
+    DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(sql);
     QueryDispatcher dispatcher = new QueryDispatcher();
     long requestId = RANDOM_REQUEST_ID_GEN.nextLong();
     try {
-      dispatcher.submitAndReduce(requestId, queryPlan, null, 10_000L, new HashMap<>(), null, false);
+      dispatcher.submitAndReduce(requestId, dispatchableSubPlan, null, 10_000L, new HashMap<>(), null, false);
       Assert.fail("Method call above should have failed");
     } catch (Exception e) {
       Assert.assertTrue(e.getMessage().contains("Error executing query"));
@@ -159,14 +160,15 @@ public class QueryDispatcherTest extends QueryTestSet {
   public void testQueryDispatcherCancelWhenQueryReducerThrowsError()
       throws Exception {
     String sql = "SELECT * FROM a";
-    QueryPlan queryPlan = _queryEnvironment.planQuery(sql);
+    DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(sql);
     QueryDispatcher dispatcher = new QueryDispatcher();
     long requestId = RANDOM_REQUEST_ID_GEN.nextLong();
     try {
       // will throw b/c mailboxService is null
-      dispatcher.submitAndReduce(requestId, queryPlan, null, 10_000L, new HashMap<>(), null, false);
+      dispatcher.submitAndReduce(requestId, dispatchableSubPlan, null, 10_000L, new HashMap<>(), null, false);
       Assert.fail("Method call above should have failed");
     } catch (Exception e) {
+      System.out.println("e = " + e);
       Assert.assertTrue(e.getMessage().contains("Error executing query"));
     }
     // wait just a little, until the cancel is being called.
@@ -192,10 +194,10 @@ public class QueryDispatcherTest extends QueryTestSet {
         return null;
       }
     }).when(failingQueryServer).submit(Mockito.any(), Mockito.any());
-    QueryPlan queryPlan = _queryEnvironment.planQuery(sql);
+    DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(sql);
     QueryDispatcher dispatcher = new QueryDispatcher();
     try {
-      dispatcher.submit(RANDOM_REQUEST_ID_GEN.nextLong(), queryPlan, 10_000L, new HashMap<>());
+      dispatcher.submit(RANDOM_REQUEST_ID_GEN.nextLong(), dispatchableSubPlan, 10_000L, new HashMap<>());
       Assert.fail("Method call above should have failed");
     } catch (Exception e) {
       Assert.assertTrue(e.getMessage().contains("Error dispatching query"));
@@ -218,10 +220,10 @@ public class QueryDispatcherTest extends QueryTestSet {
         return null;
       }
     }).when(failingQueryServer).submit(Mockito.any(), Mockito.any());
-    QueryPlan queryPlan = _queryEnvironment.planQuery(sql);
+    DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(sql);
     QueryDispatcher dispatcher = new QueryDispatcher();
     try {
-      dispatcher.submit(RANDOM_REQUEST_ID_GEN.nextLong(), queryPlan, 1_000, new HashMap<>());
+      dispatcher.submit(RANDOM_REQUEST_ID_GEN.nextLong(), dispatchableSubPlan, 1_000, new HashMap<>());
       Assert.fail("Method call above should have failed");
     } catch (Exception e) {
       Assert.assertTrue(e.getMessage().contains("Timed out waiting for response")
@@ -234,10 +236,10 @@ public class QueryDispatcherTest extends QueryTestSet {
   @Test
   public void testQueryDispatcherThrowsWhenDeadlinePreExpiredAndAsyncResponseNotPolled() {
     String sql = "SELECT * FROM a WHERE col1 = 'foo'";
-    QueryPlan queryPlan = _queryEnvironment.planQuery(sql);
+    DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(sql);
     QueryDispatcher dispatcher = new QueryDispatcher();
     try {
-      dispatcher.submit(RANDOM_REQUEST_ID_GEN.nextLong(), queryPlan, -10_000, new HashMap<>());
+      dispatcher.submit(RANDOM_REQUEST_ID_GEN.nextLong(), dispatchableSubPlan, -10_000, new HashMap<>());
       Assert.fail("Method call above should have failed");
     } catch (Exception e) {
       Assert.assertTrue(e.getMessage().contains("Timed out waiting"));


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org