Posted to commits@pinot.apache.org by ro...@apache.org on 2023/04/28 01:56:41 UTC

[pinot] branch master updated: [multistage] clean up stage metadata and mailbox usage on instance/server (#10673)

This is an automated email from the ASF dual-hosted git repository.

rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new f5dba86d5c [multistage] clean up stage metadata and mailbox usage on instance/server (#10673)
f5dba86d5c is described below

commit f5dba86d5cb19eeee9c8eb7a7cff81d92fe9958b
Author: Rong Rong <ro...@apache.org>
AuthorDate: Thu Apr 27 18:56:33 2023 -0700

    [multistage] clean up stage metadata and mailbox usage on instance/server (#10673)
    
    * [cleanup] clean up stage metadata usage
    
    ---------
    
    Co-authored-by: Rong Rong <ro...@startree.ai>
---
 .../MultiStageBrokerRequestHandler.java            |   2 +-
 pinot-common/src/main/proto/worker.proto           |  20 +--
 .../query/planner/ExplainPlanStageVisitor.java     |  82 ++++++-----
 .../org/apache/pinot/query/planner/QueryPlan.java  |  73 ++++++++--
 .../pinot/query/planner/logical/StagePlanner.java  |   9 +-
 .../planner/physical/DispatchablePlanContext.java  |  37 +++--
 .../DispatchablePlanMetadata.java}                 |  52 ++++---
 .../planner/physical/DispatchablePlanVisitor.java  |  79 ++++++----
 .../colocated/GreedyShuffleRewriteVisitor.java     |  46 +++---
 ...VirtualServer.java => QueryServerInstance.java} |  55 +++----
 .../apache/pinot/query/routing/StageMetadata.java  |  92 ++++++++++++
 .../pinot/query/routing/VirtualServerAddress.java  |  30 ++--
 .../apache/pinot/query/routing/WorkerInstance.java |  56 --------
 .../apache/pinot/query/routing/WorkerManager.java  | 160 +++++++++++----------
 .../apache/pinot/query/routing/WorkerMetadata.java | 106 ++++++++++++++
 .../apache/pinot/query/QueryCompilationTest.java   |  64 +++++----
 .../query/testutils/MockRoutingManagerFactory.java |  20 ++-
 .../apache/pinot/query/runtime/QueryRunner.java    |  40 +++---
 .../operator/BaseMailboxReceiveOperator.java       |  31 ++--
 .../runtime/operator/MailboxSendOperator.java      |  41 +++---
 .../runtime/operator/utils/OperatorUtils.java      |   9 +-
 .../query/runtime/plan/DistributedStagePlan.java   |  39 +++--
 .../runtime/plan/OpChainExecutionContext.java      |  18 +--
 .../query/runtime/plan/PlanRequestContext.java     |  13 +-
 .../runtime/plan/ServerRequestPlanVisitor.java     |   5 +-
 .../runtime/plan/serde/QueryPlanSerDeUtils.java    | 129 +++++++----------
 .../plan/server/ServerPlanRequestContext.java      |   8 +-
 .../dispatch/AsyncQueryDispatchResponse.java       |   8 +-
 .../query/service/dispatch/DispatchClient.java     |   8 +-
 .../query/service/dispatch/DispatchObserver.java   |   7 +-
 .../query/service/dispatch/QueryDispatcher.java    |  59 ++++----
 .../pinot/query/runtime/QueryRunnerTest.java       |  16 +--
 .../pinot/query/runtime/QueryRunnerTestBase.java   |  28 ++--
 .../operator/MailboxReceiveOperatorTest.java       |  83 +++++------
 .../runtime/operator/MailboxSendOperatorTest.java  |  23 ++-
 .../pinot/query/runtime/operator/OpChainTest.java  |  47 +++---
 .../query/runtime/operator/OperatorTestUtil.java   |   8 +-
 .../operator/SortedMailboxReceiveOperatorTest.java | 101 ++++++-------
 .../plan/serde/QueryPlanSerDeUtilsTest.java        |  32 ++---
 .../runtime/queries/ResourceBasedQueriesTest.java  |   6 +-
 .../pinot/query/service/QueryServerTest.java       |  83 ++++++++---
 41 files changed, 1047 insertions(+), 778 deletions(-)

diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
index 25d60f6a3a..860592392f 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
@@ -201,7 +201,7 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler {
 
     ResultTable queryResults;
     Map<Integer, ExecutionStatsAggregator> stageIdStatsMap = new HashMap<>();
-    for (Integer stageId : queryPlan.getStageMetadataMap().keySet()) {
+    for (Integer stageId : queryPlan.getDispatchablePlanMetadataMap().keySet()) {
       stageIdStatsMap.put(stageId, new ExecutionStatsAggregator(traceEnabled));
     }
 
diff --git a/pinot-common/src/main/proto/worker.proto b/pinot-common/src/main/proto/worker.proto
index 0370ad0eac..e6e5db8222 100644
--- a/pinot-common/src/main/proto/worker.proto
+++ b/pinot-common/src/main/proto/worker.proto
@@ -70,23 +70,17 @@ message QueryResponse {
 
 message StagePlan {
   int32 stageId = 1;
-  string instanceId = 2;
+  string virtualAddress = 2;
   StageNode stageRoot = 3;
-  map<int32, StageMetadata> stageMetadata = 4;
+  repeated StageMetadata stageMetadata = 4;
 }
 
 message StageMetadata {
-  repeated string instances = 1;
-  repeated string dataSources = 2;
-  map<string, SegmentMap> instanceToSegmentMap = 3;
-  string timeColumn = 4;
-  string timeValue = 5;
+  repeated WorkerMetadata workerMetadata = 1;
+  map<string, string> customProperty = 2;
 }
 
-message SegmentMap {
-  map<string, SegmentList> tableTypeToSegmentList = 1;
-}
-
-message SegmentList {
-  repeated string segments = 1;
+message WorkerMetadata {
+  string virtualAddress = 1;
+  map<string, string> customProperty = 2;
 }
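
The reshaped proto above replaces the per-instance map and segment-list messages with a repeated StageMetadata, each carrying a list of WorkerMetadata plus a free-form customProperty map. Below is a minimal, hypothetical Java sketch of populating the new messages via the generated protobuf builders; the generated package (org.apache.pinot.common.proto) and outer class name (Worker) are assumptions about the proto's Java options and are not shown in this diff.

    // Hypothetical sketch of building the new StageMetadata/WorkerMetadata protos on the
    // dispatcher side. The generated class names (Worker.*) are assumptions; verify against
    // worker.proto's java options in the repo.
    import org.apache.pinot.common.proto.Worker;

    public final class WorkerProtoSketch {
      private WorkerProtoSketch() {
      }

      public static Worker.StageMetadata buildStageMetadata() {
        // One WorkerMetadata entry per worker of the stage; list order corresponds to the worker id.
        Worker.WorkerMetadata worker0 = Worker.WorkerMetadata.newBuilder()
            .setVirtualAddress("0@server-1:8421")  // workerId@hostname:mailboxPort
            .putCustomProperty("tableSegments", "{\"OFFLINE\":[\"seg_0\",\"seg_1\"]}")  // hypothetical key
            .build();
        // Stage-level info (table name, time boundary, ...) travels in the customProperty map.
        return Worker.StageMetadata.newBuilder()
            .addWorkerMetadata(worker0)
            .putCustomProperty("tableName", "myTable")
            .build();
      }
    }
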
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/ExplainPlanStageVisitor.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/ExplainPlanStageVisitor.java
index 8f02f7018b..325c1cd7ad 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/ExplainPlanStageVisitor.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/ExplainPlanStageVisitor.java
@@ -22,7 +22,7 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
-import org.apache.pinot.core.transport.ServerInstance;
+import org.apache.pinot.query.planner.physical.DispatchablePlanMetadata;
 import org.apache.pinot.query.planner.stage.AggregateNode;
 import org.apache.pinot.query.planner.stage.FilterNode;
 import org.apache.pinot.query.planner.stage.JoinNode;
@@ -36,7 +36,7 @@ import org.apache.pinot.query.planner.stage.StageNodeVisitor;
 import org.apache.pinot.query.planner.stage.TableScanNode;
 import org.apache.pinot.query.planner.stage.ValueNode;
 import org.apache.pinot.query.planner.stage.WindowNode;
-import org.apache.pinot.query.routing.VirtualServer;
+import org.apache.pinot.query.routing.QueryServerInstance;
 
 
 /**
@@ -62,7 +62,8 @@ public class ExplainPlanStageVisitor implements StageNodeVisitor<StringBuilder,
     }
 
     // the root of a query plan always only has a single node
-    VirtualServer rootServer = queryPlan.getStageMetadataMap().get(0).getServerInstances().get(0);
+    QueryServerInstance rootServer = queryPlan.getDispatchablePlanMetadataMap().get(0).getServerInstanceToWorkerIdMap()
+        .keySet().iterator().next();
     return explainFrom(queryPlan, queryPlan.getQueryStageMap().get(0), rootServer);
   }
 
@@ -78,10 +79,10 @@ public class ExplainPlanStageVisitor implements StageNodeVisitor<StringBuilder,
    *
    * @return a query plan associated with
    */
-  public static String explainFrom(QueryPlan queryPlan, StageNode node, VirtualServer rootServer) {
+  public static String explainFrom(QueryPlan queryPlan, StageNode node, QueryServerInstance rootServer) {
     final ExplainPlanStageVisitor visitor = new ExplainPlanStageVisitor(queryPlan);
     return node
-        .visit(visitor, new Context(rootServer, "", "", new StringBuilder()))
+        .visit(visitor, new Context(rootServer, 0, "", "", new StringBuilder()))
         .toString();
   }
 
@@ -98,7 +99,7 @@ public class ExplainPlanStageVisitor implements StageNodeVisitor<StringBuilder,
         .append("]@")
         .append(context._host.getHostname())
         .append(':')
-        .append(context._host.getPort())
+        .append(context._host.getQueryServicePort())
         .append(' ')
         .append(node.explain());
     return context._builder;
@@ -106,7 +107,7 @@ public class ExplainPlanStageVisitor implements StageNodeVisitor<StringBuilder,
 
   private StringBuilder visitSimpleNode(StageNode node, Context context) {
     appendInfo(node, context).append('\n');
-    return node.getInputs().get(0).visit(this, context.next(false, context._host));
+    return node.getInputs().get(0).visit(this, context.next(false, context._host, context._workerId));
   }
 
   @Override
@@ -123,7 +124,7 @@ public class ExplainPlanStageVisitor implements StageNodeVisitor<StringBuilder,
   public StringBuilder visitSetOp(SetOpNode setOpNode, Context context) {
     appendInfo(setOpNode, context).append('\n');
     for (StageNode input : setOpNode.getInputs()) {
-      input.visit(this, context.next(false, context._host));
+      input.visit(this, context.next(false, context._host, context._workerId));
     }
     return context._builder;
   }
@@ -136,8 +137,8 @@ public class ExplainPlanStageVisitor implements StageNodeVisitor<StringBuilder,
   @Override
   public StringBuilder visitJoin(JoinNode node, Context context) {
     appendInfo(node, context).append('\n');
-    node.getInputs().get(0).visit(this, context.next(true, context._host));
-    node.getInputs().get(1).visit(this, context.next(false, context._host));
+    node.getInputs().get(0).visit(this, context.next(true, context._host, context._workerId));
+    node.getInputs().get(1).visit(this, context.next(false, context._host, context._workerId));
     return context._builder;
   }
 
@@ -147,24 +148,27 @@ public class ExplainPlanStageVisitor implements StageNodeVisitor<StringBuilder,
 
     MailboxSendNode sender = (MailboxSendNode) node.getSender();
     int senderStageId = node.getSenderStageId();
-    StageMetadata metadata = _queryPlan.getStageMetadataMap().get(senderStageId);
-    Map<ServerInstance, Map<String, List<String>>> segments = metadata.getServerInstanceToSegmentsMap();
+    DispatchablePlanMetadata metadata = _queryPlan.getDispatchablePlanMetadataMap().get(senderStageId);
+    Map<Integer, Map<String, List<String>>> segments = metadata.getWorkerIdToSegmentsMap();
 
-    Iterator<VirtualServer> iterator = metadata.getServerInstances().iterator();
+    Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap = metadata.getServerInstanceToWorkerIdMap();
+    Iterator<QueryServerInstance> iterator = serverInstanceToWorkerIdMap.keySet().iterator();
     while (iterator.hasNext()) {
-      VirtualServer serverInstance = iterator.next();
-      if (segments.containsKey(serverInstance)) {
-        // always print out leaf stages
-        sender.visit(this, context.next(iterator.hasNext(), serverInstance));
-      } else {
-        if (!iterator.hasNext()) {
-          // always print out the last one
-          sender.visit(this, context.next(false, serverInstance));
+      QueryServerInstance queryServerInstance = iterator.next();
+      for (int workerId : serverInstanceToWorkerIdMap.get(queryServerInstance)) {
+        if (segments.containsKey(workerId)) {
+          // always print out leaf stages
+          sender.visit(this, context.next(iterator.hasNext(), queryServerInstance, workerId));
         } else {
-          // only print short version of the sender node
-          appendMailboxSend(sender, context.next(true, serverInstance))
-              .append(" (Subtree Omitted)")
-              .append('\n');
+          if (!iterator.hasNext()) {
+            // always print out the last one
+            sender.visit(this, context.next(false, queryServerInstance, workerId));
+          } else {
+            // only print short version of the sender node
+            appendMailboxSend(sender, context.next(true, queryServerInstance, workerId))
+                .append(" (Subtree Omitted)")
+                .append('\n');
+          }
         }
       }
     }
@@ -174,17 +178,18 @@ public class ExplainPlanStageVisitor implements StageNodeVisitor<StringBuilder,
   @Override
   public StringBuilder visitMailboxSend(MailboxSendNode node, Context context) {
     appendMailboxSend(node, context).append('\n');
-    return node.getInputs().get(0).visit(this, context.next(false, context._host));
+    return node.getInputs().get(0).visit(this, context.next(false, context._host, context._workerId));
   }
 
   private StringBuilder appendMailboxSend(MailboxSendNode node, Context context) {
     appendInfo(node, context);
 
     int receiverStageId = node.getReceiverStageId();
-    List<VirtualServer> servers = _queryPlan.getStageMetadataMap().get(receiverStageId).getServerInstances();
+    Map<QueryServerInstance, List<Integer>> servers = _queryPlan.getDispatchablePlanMetadataMap().get(receiverStageId)
+        .getServerInstanceToWorkerIdMap();
     context._builder.append("->");
-    String receivers = servers.stream()
-        .map(VirtualServer::toString)
+    String receivers = servers.entrySet().stream()
+        .map(ExplainPlanStageVisitor::stringifyQueryServerInstanceToWorkerIdsEntry)
         .map(s -> "[" + receiverStageId + "]@" + s)
         .collect(Collectors.joining(",", "{", "}"));
     return context._builder.append(receivers);
@@ -204,10 +209,10 @@ public class ExplainPlanStageVisitor implements StageNodeVisitor<StringBuilder,
   public StringBuilder visitTableScan(TableScanNode node, Context context) {
     return appendInfo(node, context)
         .append(' ')
-        .append(_queryPlan.getStageMetadataMap()
+        .append(_queryPlan.getDispatchablePlanMetadataMap()
             .get(node.getStageId())
-            .getServerInstanceToSegmentsMap()
-            .get(context._host.getServer()))
+            .getWorkerIdToSegmentsMap()
+            .get(context._host))
         .append('\n');
   }
 
@@ -217,25 +222,32 @@ public class ExplainPlanStageVisitor implements StageNodeVisitor<StringBuilder,
   }
 
   static class Context {
-    final VirtualServer _host;
+    final QueryServerInstance _host;
+    final int _workerId;
     final String _prefix;
     final String _childPrefix;
     final StringBuilder _builder;
 
-    Context(VirtualServer host, String prefix, String childPrefix, StringBuilder builder) {
+    Context(QueryServerInstance host, int workerId, String prefix, String childPrefix, StringBuilder builder) {
       _host = host;
+      _workerId = workerId;
       _prefix = prefix;
       _childPrefix = childPrefix;
       _builder = builder;
     }
 
-    Context next(boolean hasMoreChildren, VirtualServer host) {
+    Context next(boolean hasMoreChildren, QueryServerInstance host, int workerId) {
       return new Context(
           host,
+          workerId,
           hasMoreChildren ? _childPrefix + "├── " : _childPrefix + "└── ",
           hasMoreChildren ? _childPrefix + "│   " : _childPrefix + "   ",
           _builder
       );
     }
   }
+
+  public static String stringifyQueryServerInstanceToWorkerIdsEntry(Map.Entry<QueryServerInstance, List<Integer>> e) {
+    return e.getKey() + "|" + e.getValue();
+  }
 }
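
For reference, a small standalone illustration (not part of the commit) of the receiver string that appendMailboxSend() now builds: each server entry is rendered via stringifyQueryServerInstanceToWorkerIdsEntry as host@{servicePort,mailboxPort}|[workerIds], prefixed with the receiver stage id and joined inside braces.

    import java.util.Arrays;
    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.stream.Collectors;
    import org.apache.pinot.query.routing.QueryServerInstance;

    public final class ExplainReceiverStringSketch {
      public static void main(String[] args) {
        Map<QueryServerInstance, List<Integer>> servers = new LinkedHashMap<>();
        servers.put(new QueryServerInstance("srv-1", 8421, 8422), Arrays.asList(0, 1));

        int receiverStageId = 2;
        String receivers = servers.entrySet().stream()
            .map(e -> e.getKey() + "|" + e.getValue())  // same shape as stringifyQueryServerInstanceToWorkerIdsEntry
            .map(s -> "[" + receiverStageId + "]@" + s)
            .collect(Collectors.joining(",", "{", "}"));

        // Prints: {[2]@srv-1@{8421,8422}|[0, 1]}
        System.out.println(receivers);
      }
    }
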
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/QueryPlan.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/QueryPlan.java
index 651d259979..7fbe422656 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/QueryPlan.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/QueryPlan.java
@@ -18,34 +18,42 @@
  */
 package org.apache.pinot.query.planner;
 
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import org.apache.calcite.util.Pair;
-import org.apache.pinot.query.planner.logical.LogicalPlanner;
+import org.apache.pinot.query.planner.physical.DispatchablePlanMetadata;
 import org.apache.pinot.query.planner.stage.StageNode;
+import org.apache.pinot.query.routing.QueryServerInstance;
+import org.apache.pinot.query.routing.StageMetadata;
+import org.apache.pinot.query.routing.VirtualServerAddress;
+import org.apache.pinot.query.routing.WorkerMetadata;
 
 
 /**
- * The {@code QueryPlan} is the dispatchable query execution plan from the result of {@link LogicalPlanner}.
+ * The {@code QueryPlan} is the dispatchable query execution plan from the result of
+ * {@link org.apache.pinot.query.planner.logical.StagePlanner}.
  *
  * <p>QueryPlan should contain the necessary stage boundary information and the cross exchange information
  * for:
  * <ul>
  *   <li>dispatch individual stages to executor.</li>
- *   <li>instruct stage executor to establish connection channels to other stages.</li>
- *   <li>encode data blocks for transfer between stages based on partitioning scheme.</li>
+ *   <li>instructions for the stage executor to establish connection channels to other stages.</li>
+ *   <li>instructions for encoding data blocks and transferring them between stages based on the partitioning scheme.</li>
  * </ul>
  */
 public class QueryPlan {
   private final List<Pair<Integer, String>> _queryResultFields;
   private final Map<Integer, StageNode> _queryStageMap;
-  private final Map<Integer, StageMetadata> _stageMetadataMap;
+  private final List<StageMetadata> _stageMetadataList;
+  private final Map<Integer, DispatchablePlanMetadata> _dispatchablePlanMetadataMap;
 
   public QueryPlan(List<Pair<Integer, String>> fields, Map<Integer, StageNode> queryStageMap,
-      Map<Integer, StageMetadata> stageMetadataMap) {
+      Map<Integer, DispatchablePlanMetadata> dispatchablePlanMetadataMap) {
     _queryResultFields = fields;
     _queryStageMap = queryStageMap;
-    _stageMetadataMap = stageMetadataMap;
+    _dispatchablePlanMetadataMap = dispatchablePlanMetadataMap;
+    _stageMetadataList = constructStageMetadataList(_dispatchablePlanMetadataMap);
   }
 
   /**
@@ -60,8 +68,16 @@ public class QueryPlan {
    * Get the stage metadata information.
    * @return stage metadata info.
    */
-  public Map<Integer, StageMetadata> getStageMetadataMap() {
-    return _stageMetadataMap;
+  public List<StageMetadata> getStageMetadataList() {
+    return _stageMetadataList;
+  }
+
+  /**
+   * Get the dispatch metadata information.
+   * @return dispatch metadata info.
+   */
+  public Map<Integer, DispatchablePlanMetadata> getDispatchablePlanMetadataMap() {
+    return _dispatchablePlanMetadataMap;
   }
 
   /**
@@ -84,4 +100,43 @@ public class QueryPlan {
   public String explain() {
     return ExplainPlanStageVisitor.explain(this);
   }
+
+  /**
+   * Convert the {@link DispatchablePlanMetadata} into dispatchable info for each stage/worker.
+   */
+  private static List<StageMetadata> constructStageMetadataList(
+      Map<Integer, DispatchablePlanMetadata> dispatchablePlanMetadataMap) {
+    StageMetadata[] stageMetadataList = new StageMetadata[dispatchablePlanMetadataMap.size()];
+    for (Map.Entry<Integer, DispatchablePlanMetadata> dispatchableEntry : dispatchablePlanMetadataMap.entrySet()) {
+      DispatchablePlanMetadata dispatchablePlanMetadata = dispatchableEntry.getValue();
+
+      // construct each worker metadata
+      WorkerMetadata[] workerMetadataList = new WorkerMetadata[dispatchablePlanMetadata.getTotalWorkerCount()];
+      for (Map.Entry<QueryServerInstance, List<Integer>> queryServerEntry
+          : dispatchablePlanMetadata.getServerInstanceToWorkerIdMap().entrySet()) {
+        for (int workerId : queryServerEntry.getValue()) {
+          VirtualServerAddress virtualServerAddress = new VirtualServerAddress(queryServerEntry.getKey(), workerId);
+          WorkerMetadata.Builder builder = new WorkerMetadata.Builder();
+          builder.setVirtualServerAddress(virtualServerAddress);
+          if (dispatchablePlanMetadata.getScannedTables().size() == 1) {
+            builder.addTableSegmentsMap(dispatchablePlanMetadata.getWorkerIdToSegmentsMap().get(workerId));
+          }
+          workerMetadataList[workerId] = builder.build();
+        }
+      }
+
+      // construct the stageMetadata
+      int stageId = dispatchableEntry.getKey();
+      StageMetadata.Builder builder = new StageMetadata.Builder();
+      builder.setWorkerMetadataList(Arrays.asList(workerMetadataList));
+      if (dispatchablePlanMetadata.getScannedTables().size() == 1) {
+        builder.addTableName(dispatchablePlanMetadata.getScannedTables().get(0));
+      }
+      if (dispatchablePlanMetadata.getTimeBoundaryInfo() != null) {
+        builder.addTimeBoundaryInfo(dispatchablePlanMetadata.getTimeBoundaryInfo());
+      }
+      stageMetadataList[stageId] = builder.build();
+    }
+    return Arrays.asList(stageMetadataList);
+  }
 }
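
A hedged sketch (not from the commit) of the conversion above: the planner-side DispatchablePlanMetadata map is folded by the QueryPlan constructor into the worker-facing StageMetadata list, indexed by stage id, with one WorkerMetadata slot per worker id. The getVirtualServerAddress() accessor on WorkerMetadata is assumed from context, since that class body is not shown in this diff.

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import org.apache.calcite.util.Pair;
    import org.apache.pinot.query.planner.QueryPlan;
    import org.apache.pinot.query.planner.physical.DispatchablePlanMetadata;
    import org.apache.pinot.query.planner.stage.StageNode;
    import org.apache.pinot.query.routing.QueryServerInstance;
    import org.apache.pinot.query.routing.StageMetadata;

    public final class QueryPlanMetadataSketch {
      public static QueryPlan buildSingleStagePlan(Map<Integer, StageNode> stageRoots) {
        // Planner-side view: stage 0 runs two workers on one physical server.
        DispatchablePlanMetadata metadata = new DispatchablePlanMetadata();
        Map<QueryServerInstance, List<Integer>> assignment = new HashMap<>();
        assignment.put(new QueryServerInstance("srv-1", 8421, 8422), Arrays.asList(0, 1));
        metadata.setServerInstanceToWorkerIdMap(assignment);
        metadata.setTotalWorkerCount(2);

        Map<Integer, DispatchablePlanMetadata> metadataMap = new HashMap<>();
        metadataMap.put(0, metadata);
        List<Pair<Integer, String>> resultFields = new ArrayList<>();
        QueryPlan queryPlan = new QueryPlan(resultFields, stageRoots, metadataMap);

        // Worker-facing view produced by constructStageMetadataList(): slot 1 of stage 0 carries
        // the virtual address "1@srv-1:8422" (workerId@hostname:mailboxPort).
        StageMetadata stage0 = queryPlan.getStageMetadataList().get(0);
        System.out.println(stage0.getWorkerMetadataList().get(1).getVirtualServerAddress());
        return queryPlan;
      }
    }
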
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
index 7b98d63893..ce0417ea0f 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
@@ -86,11 +86,10 @@ public class StagePlanner {
             RelDistribution.Type.RANDOM_DISTRIBUTED, null, null, false, false, globalSenderNode);
 
     // perform physical plan conversion and assign workers to each stage.
-    DispatchablePlanContext physicalPlanContext = new DispatchablePlanContext(
-        _workerManager, _requestId, _plannerContext, relRoot.fields, tableNames
-    );
-    DispatchablePlanVisitor.INSTANCE.constructDispatchablePlan(globalReceiverNode, physicalPlanContext);
-    QueryPlan queryPlan = physicalPlanContext.getQueryPlan();
+    DispatchablePlanContext dispatchablePlanContext = new DispatchablePlanContext(_workerManager, _requestId,
+        _plannerContext, relRoot.fields, tableNames);
+    QueryPlan queryPlan = DispatchablePlanVisitor.INSTANCE.constructDispatchablePlan(globalReceiverNode,
+        dispatchablePlanContext);
 
     // Run physical optimizations
     runPhysicalOptimizers(queryPlan);
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanContext.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanContext.java
index 7f333775b7..227aa5429b 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanContext.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanContext.java
@@ -20,33 +20,36 @@ package org.apache.pinot.query.planner.physical;
 
 import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import org.apache.calcite.util.Pair;
 import org.apache.pinot.query.context.PlannerContext;
-import org.apache.pinot.query.planner.QueryPlan;
+import org.apache.pinot.query.planner.stage.StageNode;
 import org.apache.pinot.query.routing.WorkerManager;
 
 
 public class DispatchablePlanContext {
   private final WorkerManager _workerManager;
+
   private final long _requestId;
-  private final PlannerContext _plannerContext;
-  private final QueryPlan _queryPlan;
   private final Set<String> _tableNames;
+  private final List<Pair<Integer, String>> _resultFields;
+
+  private final PlannerContext _plannerContext;
+  private final Map<Integer, DispatchablePlanMetadata> _dispatchablePlanMetadataMap;
+  private final Map<Integer, StageNode> _dispatchablePlanStageRootMap;
 
   public DispatchablePlanContext(WorkerManager workerManager, long requestId, PlannerContext plannerContext,
       List<Pair<Integer, String>> resultFields, Set<String> tableNames) {
     _workerManager = workerManager;
     _requestId = requestId;
     _plannerContext = plannerContext;
-    _queryPlan = new QueryPlan(resultFields, new HashMap<>(), new HashMap<>());
+    _dispatchablePlanMetadataMap = new HashMap<>();
+    _dispatchablePlanStageRootMap = new HashMap<>();
+    _resultFields = resultFields;
     _tableNames = tableNames;
   }
 
-  public QueryPlan getQueryPlan() {
-    return _queryPlan;
-  }
-
   public WorkerManager getWorkerManager() {
     return _workerManager;
   }
@@ -55,12 +58,24 @@ public class DispatchablePlanContext {
     return _requestId;
   }
 
+  // Returns all the table names.
+  public Set<String> getTableNames() {
+    return _tableNames;
+  }
+
+  public List<Pair<Integer, String>> getResultFields() {
+    return _resultFields;
+  }
+
   public PlannerContext getPlannerContext() {
     return _plannerContext;
   }
 
-  // Returns all the table names.
-  public Set<String> getTableNames() {
-    return _tableNames;
+  public Map<Integer, DispatchablePlanMetadata> getDispatchablePlanMetadataMap() {
+    return _dispatchablePlanMetadataMap;
+  }
+
+  public Map<Integer, StageNode> getDispatchablePlanStageRootMap() {
+    return _dispatchablePlanStageRootMap;
   }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/StageMetadata.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanMetadata.java
similarity index 64%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/planner/StageMetadata.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanMetadata.java
index 24650617e8..a77ad757ce 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/StageMetadata.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanMetadata.java
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.planner;
+package org.apache.pinot.query.planner.physical;
 
 import java.io.Serializable;
 import java.util.ArrayList;
@@ -24,8 +24,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import org.apache.pinot.core.routing.TimeBoundaryInfo;
-import org.apache.pinot.core.transport.ServerInstance;
-import org.apache.pinot.query.routing.VirtualServer;
+import org.apache.pinot.query.routing.QueryServerInstance;
 
 
 /**
@@ -38,16 +37,16 @@ import org.apache.pinot.query.routing.VirtualServer;
 *   <li>the server instances on which this stage should be executed</li>
  * </ul>
  */
-public class StageMetadata implements Serializable {
+public class DispatchablePlanMetadata implements Serializable {
   private List<String> _scannedTables;
 
   // used for assigning server/worker nodes.
-  private List<VirtualServer> _serverInstances;
+  private Map<QueryServerInstance, List<Integer>> _serverInstanceToWorkerIdMap;
 
   // used for table scan stage - we use ServerInstance instead of VirtualServer
   // here because all virtual servers that share a server instance will have the
   // same segments on them
-  private Map<ServerInstance, Map<String, List<String>>> _serverInstanceToSegmentsMap;
+  private Map<Integer, Map<String, List<String>>> _workerIdToSegmentsMap;
 
   // time boundary info
   private TimeBoundaryInfo _timeBoundaryInfo;
@@ -55,10 +54,13 @@ public class StageMetadata implements Serializable {
   // whether a stage requires singleton instance to execute, e.g. stage contains global reduce (sort/agg) operator.
   private boolean _requiresSingletonInstance;
 
-  public StageMetadata() {
+  // Total worker count of this stage.
+  private int _totalWorkerCount;
+
+  public DispatchablePlanMetadata() {
     _scannedTables = new ArrayList<>();
-    _serverInstances = new ArrayList<>();
-    _serverInstanceToSegmentsMap = new HashMap<>();
+    _serverInstanceToWorkerIdMap = new HashMap<>();
+    _workerIdToSegmentsMap = new HashMap<>();
     _timeBoundaryInfo = null;
     _requiresSingletonInstance = false;
   }
@@ -75,21 +77,21 @@ public class StageMetadata implements Serializable {
   // attached physical plan context.
   // -----------------------------------------------
 
-  public Map<ServerInstance, Map<String, List<String>>> getServerInstanceToSegmentsMap() {
-    return _serverInstanceToSegmentsMap;
+  public Map<Integer, Map<String, List<String>>> getWorkerIdToSegmentsMap() {
+    return _workerIdToSegmentsMap;
   }
 
-  public void setServerInstanceToSegmentsMap(
-      Map<ServerInstance, Map<String, List<String>>> serverInstanceToSegmentsMap) {
-    _serverInstanceToSegmentsMap = serverInstanceToSegmentsMap;
+  public void setWorkerIdToSegmentsMap(
+      Map<Integer, Map<String, List<String>>> workerIdToSegmentsMap) {
+    _workerIdToSegmentsMap = workerIdToSegmentsMap;
   }
 
-  public List<VirtualServer> getServerInstances() {
-    return _serverInstances;
+  public Map<QueryServerInstance, List<Integer>> getServerInstanceToWorkerIdMap() {
+    return _serverInstanceToWorkerIdMap;
   }
 
-  public void setServerInstances(List<VirtualServer> serverInstances) {
-    _serverInstances = serverInstances;
+  public void setServerInstanceToWorkerIdMap(Map<QueryServerInstance, List<Integer>> serverInstances) {
+    _serverInstanceToWorkerIdMap = serverInstances;
   }
 
   public TimeBoundaryInfo getTimeBoundaryInfo() {
@@ -108,10 +110,18 @@ public class StageMetadata implements Serializable {
     _requiresSingletonInstance = _requiresSingletonInstance || newRequireInstance;
   }
 
+  public int getTotalWorkerCount() {
+    return _totalWorkerCount;
+  }
+
+  public void setTotalWorkerCount(int totalWorkerCount) {
+    _totalWorkerCount = totalWorkerCount;
+  }
+
   @Override
   public String toString() {
-    return "StageMetadata{" + "_scannedTables=" + _scannedTables + ", _serverInstances=" + _serverInstances
-        + ", _serverInstanceToSegmentsMap=" + _serverInstanceToSegmentsMap + ", _timeBoundaryInfo=" + _timeBoundaryInfo
-        + '}';
+    return "DispatchablePlanMetadata{" + "_scannedTables=" + _scannedTables + ", _servers="
+        + _serverInstanceToWorkerIdMap + ", _serverInstanceToSegmentsMap=" + _workerIdToSegmentsMap
+        + ", _timeBoundaryInfo=" + _timeBoundaryInfo + '}';
   }
 }
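
A minimal sketch (not from the commit) of how a leaf table-scan stage's DispatchablePlanMetadata is now populated: worker ids are grouped under their physical QueryServerInstance, and segment routing is keyed by worker id and then by table type, matching the old tableTypeToSegmentList shape. The hostnames, ports and segment names are illustrative only.

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import org.apache.pinot.query.planner.physical.DispatchablePlanMetadata;
    import org.apache.pinot.query.routing.QueryServerInstance;

    public final class LeafStageAssignmentSketch {
      public static DispatchablePlanMetadata assignLeafStage() {
        DispatchablePlanMetadata metadata = new DispatchablePlanMetadata();
        metadata.addScannedTable("myTable_OFFLINE");

        // Physical placement: workers 0 and 1 on srv-1, worker 2 on srv-2.
        Map<QueryServerInstance, List<Integer>> workersByServer = new HashMap<>();
        workersByServer.put(new QueryServerInstance("srv-1", 8421, 8422), Arrays.asList(0, 1));
        workersByServer.put(new QueryServerInstance("srv-2", 8421, 8422), Arrays.asList(2));
        metadata.setServerInstanceToWorkerIdMap(workersByServer);
        metadata.setTotalWorkerCount(3);

        // Segment routing: keyed by workerId, then by table type.
        Map<Integer, Map<String, List<String>>> segmentsByWorker = new HashMap<>();
        segmentsByWorker.put(0, Map.of("OFFLINE", Arrays.asList("seg_0", "seg_1")));
        segmentsByWorker.put(1, Map.of("OFFLINE", Arrays.asList("seg_2")));
        segmentsByWorker.put(2, Map.of("OFFLINE", Arrays.asList("seg_3")));
        metadata.setWorkerIdToSegmentsMap(segmentsByWorker);
        return metadata;
      }
    }
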
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanVisitor.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanVisitor.java
index ea4eaf7ac7..c67bfbd887 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanVisitor.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanVisitor.java
@@ -18,7 +18,7 @@
  */
 package org.apache.pinot.query.planner.physical;
 
-import org.apache.pinot.query.planner.StageMetadata;
+import org.apache.pinot.query.planner.QueryPlan;
 import org.apache.pinot.query.planner.stage.AggregateNode;
 import org.apache.pinot.query.planner.stage.FilterNode;
 import org.apache.pinot.query.planner.stage.JoinNode;
@@ -41,82 +41,101 @@ public class DispatchablePlanVisitor implements StageNodeVisitor<Void, Dispatcha
   }
 
   /**
-   * Entry point
-   * @param globalReceiverNode
-   * @param physicalPlanContext
+   * Entry point for attaching dispatch metadata to a query plan. It walks through the plan via the global
+   * {@link StageNode} root of the query and:
+   * <ul>
+   *   <li>breaks down the {@link StageNode}s into stages that can each run on a single worker.</li>
+   *   <li>each stage is represented by a subset of {@link StageNode}s without data exchange.</li>
+   *   <li>attaches worker execution information, including the physical server address and worker ID, to each stage.</li>
+   * </ul>
+   *
+   * @param globalReceiverNode the entrypoint of the stage plan.
+   * @param dispatchablePlanContext dispatchable plan context used to record the walk of the stage node tree.
    */
-  public void constructDispatchablePlan(StageNode globalReceiverNode, DispatchablePlanContext physicalPlanContext) {
-    globalReceiverNode.visit(DispatchablePlanVisitor.INSTANCE, physicalPlanContext);
-    // special case for the global mailbox receive node
-    physicalPlanContext.getQueryPlan().getQueryStageMap().put(0, globalReceiverNode);
-    computeWorkerAssignment(globalReceiverNode, physicalPlanContext);
+  public QueryPlan constructDispatchablePlan(StageNode globalReceiverNode,
+      DispatchablePlanContext dispatchablePlanContext) {
+    // 1. start by visiting the stage root.
+    globalReceiverNode.visit(DispatchablePlanVisitor.INSTANCE, dispatchablePlanContext);
+    // 2. add a special stage for the global mailbox receive; this stage runs on the dispatcher.
+    dispatchablePlanContext.getDispatchablePlanStageRootMap().put(0, globalReceiverNode);
+    // 3. compute the worker assignment once the visit has populated the dispatchable plan context.
+    computeWorkerAssignment(globalReceiverNode, dispatchablePlanContext);
+    // 4. convert it into query plan.
+    return finalizeQueryPlan(dispatchablePlanContext);
   }
 
-  private StageMetadata getStageMetadata(StageNode node, DispatchablePlanContext context) {
-    return context.getQueryPlan().getStageMetadataMap().computeIfAbsent(
-        node.getStageId(), (id) -> new StageMetadata());
+  private static QueryPlan finalizeQueryPlan(DispatchablePlanContext dispatchablePlanContext) {
+    return new QueryPlan(dispatchablePlanContext.getResultFields(),
+        dispatchablePlanContext.getDispatchablePlanStageRootMap(),
+        dispatchablePlanContext.getDispatchablePlanMetadataMap());
   }
 
-  private void computeWorkerAssignment(StageNode node, DispatchablePlanContext context) {
+  private static DispatchablePlanMetadata getOrCreateDispatchablePlanMetadata(StageNode node,
+      DispatchablePlanContext context) {
+    return context.getDispatchablePlanMetadataMap().computeIfAbsent(node.getStageId(),
+        (id) -> new DispatchablePlanMetadata());
+  }
+
+  private static void computeWorkerAssignment(StageNode node, DispatchablePlanContext context) {
     int stageId = node.getStageId();
-    context.getWorkerManager().assignWorkerToStage(stageId, context.getQueryPlan().getStageMetadataMap().get(stageId),
+    context.getWorkerManager().assignWorkerToStage(stageId, context.getDispatchablePlanMetadataMap().get(stageId),
         context.getRequestId(), context.getPlannerContext().getOptions(), context.getTableNames());
   }
 
   @Override
   public Void visitAggregate(AggregateNode node, DispatchablePlanContext context) {
     node.getInputs().get(0).visit(this, context);
-    StageMetadata stageMetadata = getStageMetadata(node, context);
-    stageMetadata.setRequireSingleton(node.getGroupSet().size() == 0 && AggregateNode.isFinalStage(node));
+    DispatchablePlanMetadata dispatchablePlanMetadata = getOrCreateDispatchablePlanMetadata(node, context);
+    dispatchablePlanMetadata.setRequireSingleton(node.getGroupSet().size() == 0 && AggregateNode.isFinalStage(node));
     return null;
   }
 
   @Override
   public Void visitWindow(WindowNode node, DispatchablePlanContext context) {
     node.getInputs().get(0).visit(this, context);
-    StageMetadata stageMetadata = getStageMetadata(node, context);
+    DispatchablePlanMetadata dispatchablePlanMetadata = getOrCreateDispatchablePlanMetadata(node, context);
     // TODO: Figure out a way to parallelize Empty OVER() and OVER(ORDER BY) so the computation can be done across
     //       multiple nodes.
     // Empty OVER() and OVER(ORDER BY) need to be processed on a singleton node. OVER() with PARTITION BY can be
     // distributed as no global ordering is required across partitions.
-    stageMetadata.setRequireSingleton(node.getGroupSet().size() == 0);
+    dispatchablePlanMetadata.setRequireSingleton(node.getGroupSet().size() == 0);
     return null;
   }
 
   @Override
   public Void visitSetOp(SetOpNode setOpNode, DispatchablePlanContext context) {
     setOpNode.getInputs().forEach(input -> input.visit(this, context));
-    getStageMetadata(setOpNode, context);
+    getOrCreateDispatchablePlanMetadata(setOpNode, context);
     return null;
   }
 
   @Override
   public Void visitFilter(FilterNode node, DispatchablePlanContext context) {
     node.getInputs().get(0).visit(this, context);
-    getStageMetadata(node, context);
+    getOrCreateDispatchablePlanMetadata(node, context);
     return null;
   }
 
   @Override
   public Void visitJoin(JoinNode node, DispatchablePlanContext context) {
     node.getInputs().forEach(join -> join.visit(this, context));
-    getStageMetadata(node, context);
+    getOrCreateDispatchablePlanMetadata(node, context);
     return null;
   }
 
   @Override
   public Void visitMailboxReceive(MailboxReceiveNode node, DispatchablePlanContext context) {
     node.getSender().visit(this, context);
-    getStageMetadata(node, context);
+    getOrCreateDispatchablePlanMetadata(node, context);
     return null;
   }
 
   @Override
   public Void visitMailboxSend(MailboxSendNode node, DispatchablePlanContext context) {
     node.getInputs().get(0).visit(this, context);
-    getStageMetadata(node, context);
+    getOrCreateDispatchablePlanMetadata(node, context);
 
-    context.getQueryPlan().getQueryStageMap().put(node.getStageId(), node);
+    context.getDispatchablePlanStageRootMap().put(node.getStageId(), node);
     computeWorkerAssignment(node, context);
     return null;
   }
@@ -124,28 +143,28 @@ public class DispatchablePlanVisitor implements StageNodeVisitor<Void, Dispatcha
   @Override
   public Void visitProject(ProjectNode node, DispatchablePlanContext context) {
     node.getInputs().get(0).visit(this, context);
-    getStageMetadata(node, context);
+    getOrCreateDispatchablePlanMetadata(node, context);
     return null;
   }
 
   @Override
   public Void visitSort(SortNode node, DispatchablePlanContext context) {
     node.getInputs().get(0).visit(this, context);
-    StageMetadata stageMetadata = getStageMetadata(node, context);
-    stageMetadata.setRequireSingleton(node.getCollationKeys().size() > 0 && node.getOffset() != -1);
+    DispatchablePlanMetadata dispatchablePlanMetadata = getOrCreateDispatchablePlanMetadata(node, context);
+    dispatchablePlanMetadata.setRequireSingleton(node.getCollationKeys().size() > 0 && node.getOffset() != -1);
     return null;
   }
 
   @Override
   public Void visitTableScan(TableScanNode node, DispatchablePlanContext context) {
-    StageMetadata stageMetadata = getStageMetadata(node, context);
-    stageMetadata.addScannedTable(node.getTableName());
+    DispatchablePlanMetadata dispatchablePlanMetadata = getOrCreateDispatchablePlanMetadata(node, context);
+    dispatchablePlanMetadata.addScannedTable(node.getTableName());
     return null;
   }
 
   @Override
   public Void visitValue(ValueNode node, DispatchablePlanContext context) {
-    getStageMetadata(node, context);
+    getOrCreateDispatchablePlanMetadata(node, context);
     return null;
   }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewriteVisitor.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewriteVisitor.java
index b5d77c1193..4f89435a6e 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewriteVisitor.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/colocated/GreedyShuffleRewriteVisitor.java
@@ -30,10 +30,10 @@ import java.util.stream.Collectors;
 import org.apache.calcite.rel.RelDistribution;
 import org.apache.pinot.common.config.provider.TableCache;
 import org.apache.pinot.query.planner.QueryPlan;
-import org.apache.pinot.query.planner.StageMetadata;
 import org.apache.pinot.query.planner.logical.RexExpression;
 import org.apache.pinot.query.planner.partitioning.FieldSelectionKeySelector;
 import org.apache.pinot.query.planner.partitioning.KeySelector;
+import org.apache.pinot.query.planner.physical.DispatchablePlanMetadata;
 import org.apache.pinot.query.planner.stage.AggregateNode;
 import org.apache.pinot.query.planner.stage.FilterNode;
 import org.apache.pinot.query.planner.stage.JoinNode;
@@ -47,7 +47,7 @@ import org.apache.pinot.query.planner.stage.StageNodeVisitor;
 import org.apache.pinot.query.planner.stage.TableScanNode;
 import org.apache.pinot.query.planner.stage.ValueNode;
 import org.apache.pinot.query.planner.stage.WindowNode;
-import org.apache.pinot.query.routing.VirtualServer;
+import org.apache.pinot.query.routing.QueryServerInstance;
 import org.apache.pinot.spi.config.table.ColumnPartitionConfig;
 import org.apache.pinot.spi.config.table.IndexingConfig;
 import org.apache.pinot.spi.config.table.TableConfig;
@@ -70,24 +70,25 @@ public class GreedyShuffleRewriteVisitor implements StageNodeVisitor<Set<Colocat
   private static final Logger LOGGER = LoggerFactory.getLogger(GreedyShuffleRewriteVisitor.class);
 
   private final TableCache _tableCache;
-  private final Map<Integer, StageMetadata> _stageMetadataMap;
+  private final Map<Integer, DispatchablePlanMetadata> _dispatchablePlanMetadataMap;
   private boolean _canSkipShuffleForJoin;
 
   public static void optimizeShuffles(QueryPlan queryPlan, TableCache tableCache) {
     StageNode rootStageNode = queryPlan.getQueryStageMap().get(0);
-    Map<Integer, StageMetadata> stageMetadataMap = queryPlan.getStageMetadataMap();
+    Map<Integer, DispatchablePlanMetadata> dispatchablePlanMetadataMap = queryPlan.getDispatchablePlanMetadataMap();
     GreedyShuffleRewriteContext context = GreedyShuffleRewritePreComputeVisitor.preComputeContext(rootStageNode);
     // This assumes that if stageId(S1) > stageId(S2), then S1 is not an ancestor of S2.
     // TODO: If this assumption is wrong, we can compute the reverse topological ordering explicitly.
-    for (int stageId = stageMetadataMap.size() - 1; stageId >= 0; stageId--) {
+    for (int stageId = dispatchablePlanMetadataMap.size() - 1; stageId >= 0; stageId--) {
       StageNode stageNode = context.getRootStageNode(stageId);
-      stageNode.visit(new GreedyShuffleRewriteVisitor(tableCache, stageMetadataMap), context);
+      stageNode.visit(new GreedyShuffleRewriteVisitor(tableCache, dispatchablePlanMetadataMap), context);
     }
   }
 
-  private GreedyShuffleRewriteVisitor(TableCache tableCache, Map<Integer, StageMetadata> stageMetadataMap) {
+  private GreedyShuffleRewriteVisitor(TableCache tableCache,
+      Map<Integer, DispatchablePlanMetadata> dispatchablePlanMetadataMap) {
     _tableCache = tableCache;
-    _stageMetadataMap = stageMetadataMap;
+    _dispatchablePlanMetadataMap = dispatchablePlanMetadataMap;
     _canSkipShuffleForJoin = false;
   }
 
@@ -138,8 +139,8 @@ public class GreedyShuffleRewriteVisitor implements StageNodeVisitor<Set<Colocat
     canColocate = canColocate && checkPartitionScheme(innerLeafNodes.get(0), innerLeafNodes.get(1), context);
     if (canColocate) {
       // If shuffle can be skipped, reassign servers.
-      _stageMetadataMap.get(node.getStageId())
-          .setServerInstances(_stageMetadataMap.get(innerLeafNodes.get(0).getSenderStageId()).getServerInstances());
+      _dispatchablePlanMetadataMap.get(node.getStageId()).setServerInstanceToWorkerIdMap(
+          _dispatchablePlanMetadataMap.get(innerLeafNodes.get(0).getSenderStageId()).getServerInstanceToWorkerIdMap());
       _canSkipShuffleForJoin = true;
     }
 
@@ -172,12 +173,12 @@ public class GreedyShuffleRewriteVisitor implements StageNodeVisitor<Set<Colocat
       } else if (colocationKeyCondition(oldColocationKeys, selector) && areServersSuperset(node.getStageId(),
           node.getSenderStageId())) {
         node.setExchangeType(RelDistribution.Type.SINGLETON);
-        _stageMetadataMap.get(node.getStageId())
-            .setServerInstances(_stageMetadataMap.get(node.getSenderStageId()).getServerInstances());
+        _dispatchablePlanMetadataMap.get(node.getStageId()).setServerInstanceToWorkerIdMap(
+            _dispatchablePlanMetadataMap.get(node.getSenderStageId()).getServerInstanceToWorkerIdMap());
         return oldColocationKeys;
       }
       // This means we can't skip shuffle and there's a partitioning enforced by receiver.
-      int numPartitions = _stageMetadataMap.get(node.getStageId()).getServerInstances().size();
+      int numPartitions = _dispatchablePlanMetadataMap.get(node.getStageId()).getServerInstanceToWorkerIdMap().size();
       List<ColocationKey> colocationKeys = ((FieldSelectionKeySelector) selector).getColumnIndices().stream()
           .map(x -> new ColocationKey(x, numPartitions, selector.hashAlgorithm())).collect(Collectors.toList());
       return new HashSet<>(colocationKeys);
@@ -193,7 +194,7 @@ public class GreedyShuffleRewriteVisitor implements StageNodeVisitor<Set<Colocat
       return new HashSet<>();
     }
     // This means we can't skip shuffle and there's a partitioning enforced by receiver.
-    int numPartitions = _stageMetadataMap.get(node.getStageId()).getServerInstances().size();
+    int numPartitions = _dispatchablePlanMetadataMap.get(node.getStageId()).getServerInstanceToWorkerIdMap().size();
     List<ColocationKey> colocationKeys = ((FieldSelectionKeySelector) selector).getColumnIndices().stream()
         .map(x -> new ColocationKey(x, numPartitions, selector.hashAlgorithm())).collect(Collectors.toList());
     return new HashSet<>(colocationKeys);
@@ -298,8 +299,8 @@ public class GreedyShuffleRewriteVisitor implements StageNodeVisitor<Set<Colocat
    * Checks if servers assigned to the receiver stage are a super-set of the sender stage.
    */
   private boolean areServersSuperset(int receiverStageId, int senderStageId) {
-    return _stageMetadataMap.get(receiverStageId).getServerInstances()
-        .containsAll(_stageMetadataMap.get(senderStageId).getServerInstances());
+    return _dispatchablePlanMetadataMap.get(receiverStageId).getServerInstanceToWorkerIdMap().keySet()
+        .containsAll(_dispatchablePlanMetadataMap.get(senderStageId).getServerInstanceToWorkerIdMap().keySet());
   }
 
   /*
@@ -308,12 +309,15 @@ public class GreedyShuffleRewriteVisitor implements StageNodeVisitor<Set<Colocat
    * 2. Servers assigned to the join-stage are a superset of S.
    */
   private boolean canServerAssignmentAllowShuffleSkip(int currentStageId, int leftStageId, int rightStageId) {
-    Set<VirtualServer> leftServerInstances = new HashSet<>(_stageMetadataMap.get(leftStageId).getServerInstances());
-    List<VirtualServer> rightServerInstances = _stageMetadataMap.get(rightStageId).getServerInstances();
-    List<VirtualServer> currentServerInstances = _stageMetadataMap.get(currentStageId).getServerInstances();
+    Set<QueryServerInstance> leftServerInstances = new HashSet<>(_dispatchablePlanMetadataMap.get(leftStageId)
+        .getServerInstanceToWorkerIdMap().keySet());
+    Set<QueryServerInstance> rightServerInstances = _dispatchablePlanMetadataMap.get(rightStageId)
+        .getServerInstanceToWorkerIdMap().keySet();
+    Set<QueryServerInstance> currentServerInstances = _dispatchablePlanMetadataMap.get(currentStageId)
+        .getServerInstanceToWorkerIdMap().keySet();
     return leftServerInstances.containsAll(rightServerInstances)
-        && leftServerInstances.size() == rightServerInstances.size() && currentServerInstances.containsAll(
-        leftServerInstances);
+        && leftServerInstances.size() == rightServerInstances.size()
+        && currentServerInstances.containsAll(leftServerInstances);
   }
 
   /**
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/VirtualServer.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/QueryServerInstance.java
similarity index 54%
rename from pinot-query-planner/src/main/java/org/apache/pinot/query/routing/VirtualServer.java
rename to pinot-query-planner/src/main/java/org/apache/pinot/query/routing/QueryServerInstance.java
index c50aee41e7..b9442b728c 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/VirtualServer.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/QueryServerInstance.java
@@ -24,48 +24,36 @@ import org.apache.pinot.core.transport.ServerInstance;
 
 
 /**
- * {@code VirtualServer} is a {@link ServerInstance} associated with a
- * unique virtualization identifier which allows the multistage query
- * engine to collocate multiple virtual servers on a single physical
- * instance, enabling higher levels of parallelism and partitioning
- * the query input.
+ * {@code QueryServerInstance} is the representation used during query dispatch to indicate the
+ * physical location of a query server.
+ *
+ * <p>Note that {@code QueryServerInstance} should only be used during dispatch.</p>
  */
-public class VirtualServer {
-
-  private final ServerInstance _server;
-  private final int _virtualId;
-
-  public VirtualServer(ServerInstance server, int virtualId) {
-    _server = server;
-    _virtualId = virtualId;
-  }
+public class QueryServerInstance {
+  private final String _hostname;
+  private final int _queryServicePort;
+  private final int _queryMailboxPort;
 
-  public ServerInstance getServer() {
-    return _server;
+  public QueryServerInstance(ServerInstance server) {
+    this(server.getHostname(), server.getQueryServicePort(), server.getQueryMailboxPort());
   }
 
-  public int getVirtualId() {
-    return _virtualId;
+  public QueryServerInstance(String hostName, int servicePort, int mailboxPort) {
+    _hostname = hostName;
+    _queryServicePort = servicePort;
+    _queryMailboxPort = mailboxPort;
   }
 
   public String getHostname() {
-    return _server.getHostname();
-  }
-
-  public int getPort() {
-    return _server.getPort();
+    return _hostname;
   }
 
   public int getQueryMailboxPort() {
-    return _server.getQueryMailboxPort();
+    return _queryMailboxPort;
   }
 
   public int getQueryServicePort() {
-    return _server.getQueryServicePort();
-  }
-
-  public int getGrpcPort() {
-    return _server.getGrpcPort();
+    return _queryServicePort;
   }
 
   @Override
@@ -76,17 +64,18 @@ public class VirtualServer {
     if (o == null || getClass() != o.getClass()) {
       return false;
     }
-    VirtualServer that = (VirtualServer) o;
-    return _virtualId == that._virtualId && Objects.equals(_server, that._server);
+    QueryServerInstance that = (QueryServerInstance) o;
+    return _hostname.equals(that._hostname) && _queryServicePort == that._queryServicePort
+        && _queryMailboxPort == that._queryMailboxPort;
   }
 
   @Override
   public int hashCode() {
-    return Objects.hash(_server, _virtualId);
+    return Objects.hash(_hostname, _queryServicePort, _queryMailboxPort);
   }
 
   @Override
   public String toString() {
-    return _virtualId + "@" + _server.getInstanceId();
+    return _hostname + "@{" + _queryServicePort + "," + _queryMailboxPort + "}";
   }
 }
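
A short illustration (not part of the commit) of the identity semantics above: QueryServerInstance equality is purely hostname plus the two query ports, so instances built independently for the same physical server collapse to a single key in a serverInstanceToWorkerIdMap.

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import org.apache.pinot.query.routing.QueryServerInstance;

    public final class QueryServerInstanceSketch {
      public static void main(String[] args) {
        QueryServerInstance a = new QueryServerInstance("srv-1", 8421, 8422);
        QueryServerInstance b = new QueryServerInstance("srv-1", 8421, 8422);

        Map<QueryServerInstance, List<Integer>> workersByServer = new HashMap<>();
        workersByServer.put(a, Arrays.asList(0));
        workersByServer.put(b, Arrays.asList(0, 1));  // replaces the entry keyed by 'a'

        System.out.println(a.equals(b));            // true
        System.out.println(workersByServer.size()); // 1
        System.out.println(a);                      // srv-1@{8421,8422}
      }
    }
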
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/StageMetadata.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/StageMetadata.java
new file mode 100644
index 0000000000..4fc6c2ef50
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/StageMetadata.java
@@ -0,0 +1,92 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.routing;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.core.routing.TimeBoundaryInfo;
+
+
+/**
+ * {@code StageMetadata} is used to send stage-level info about how to execute a stage physically.
+ */
+public class StageMetadata {
+  private final List<WorkerMetadata> _workerMetadataList;
+  private final Map<String, String> _customProperties;
+
+  public StageMetadata(List<WorkerMetadata> workerMetadataList, Map<String, String> customProperties) {
+    _workerMetadataList = workerMetadataList;
+    _customProperties = customProperties;
+  }
+
+  public List<WorkerMetadata> getWorkerMetadataList() {
+    return _workerMetadataList;
+  }
+
+  public Map<String, String> getCustomProperties() {
+    return _customProperties;
+  }
+
+  public static class Builder {
+    public static final String TABLE_NAME_KEY = "tableName";
+    public static final String TIME_BOUNDARY_COLUMN_KEY = "timeBoundaryInfo.timeColumn";
+    public static final String TIME_BOUNDARY_VALUE_KEY = "timeBoundaryInfo.timeValue";
+    private List<WorkerMetadata> _workerMetadataList;
+    private Map<String, String> _customProperties;
+
+    public Builder() {
+      _customProperties = new HashMap<>();
+    }
+
+    public Builder setWorkerMetadataList(List<WorkerMetadata> workerMetadataList) {
+      _workerMetadataList = workerMetadataList;
+      return this;
+    }
+
+    public Builder addTableName(String tableName) {
+      _customProperties.put(TABLE_NAME_KEY, tableName);
+      return this;
+    }
+
+    public Builder addTimeBoundaryInfo(TimeBoundaryInfo timeBoundaryInfo) {
+      _customProperties.put(TIME_BOUNDARY_COLUMN_KEY, timeBoundaryInfo.getTimeColumn());
+      _customProperties.put(TIME_BOUNDARY_VALUE_KEY, timeBoundaryInfo.getTimeValue());
+      return this;
+    }
+
+    public StageMetadata build() {
+      return new StageMetadata(_workerMetadataList, _customProperties);
+    }
+
+    public void putAllCustomProperties(Map<String, String> customPropertyMap) {
+      _customProperties.putAll(customPropertyMap);
+    }
+  }
+
+  public static String getTableName(StageMetadata metadata) {
+    return metadata.getCustomProperties().get(Builder.TABLE_NAME_KEY);
+  }
+
+  public static TimeBoundaryInfo getTimeBoundary(StageMetadata metadata) {
+    String timeColumn = metadata.getCustomProperties().get(Builder.TIME_BOUNDARY_COLUMN_KEY);
+    String timeValue = metadata.getCustomProperties().get(Builder.TIME_BOUNDARY_VALUE_KEY);
+    return timeColumn != null && timeValue != null ? new TimeBoundaryInfo(timeColumn, timeValue) : null;
+  }
+}
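
A minimal usage sketch of the new builder and its typed custom-property helpers, mirroring how QueryPlan#constructStageMetadataList fills it in. The worker list is left empty for brevity, and the two-argument TimeBoundaryInfo(timeColumn, timeValue) constructor is assumed from the getters used above.

    import java.util.Collections;
    import org.apache.pinot.core.routing.TimeBoundaryInfo;
    import org.apache.pinot.query.routing.StageMetadata;

    public final class StageMetadataBuilderSketch {
      public static void main(String[] args) {
        StageMetadata stageMetadata = new StageMetadata.Builder()
            .setWorkerMetadataList(Collections.emptyList())
            .addTableName("myTable")
            .addTimeBoundaryInfo(new TimeBoundaryInfo("daysSinceEpoch", "19000"))  // assumed ctor
            .build();

        // The static helpers read typed values back out of the free-form customProperties map.
        System.out.println(StageMetadata.getTableName(stageMetadata));                   // myTable
        System.out.println(StageMetadata.getTimeBoundary(stageMetadata).getTimeValue()); // 19000
      }
    }
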
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/VirtualServerAddress.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/VirtualServerAddress.java
index 99074f7191..5b2f3b012d 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/VirtualServerAddress.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/VirtualServerAddress.java
@@ -23,29 +23,23 @@ import java.util.Objects;
 
 
 /**
- * Represents the address of a {@link VirtualServer} containing
- * both the ID of the specific virtualized server and the physical
- * internet address in id@hostname:port format.
- *
- * <p>This is needed in addition to {@code VirtualServer} because there
- * are some parts of the code that don't have enough information to
- * construct the full {@code VirtualServer} and only require the
- * hostname, port and virtualId.</p>
+ * Represents the address of a {@link QueryServerInstance} containing both the ID of the specific worker and the
+ * physical host/port info from {@link QueryServerInstance}.
  */
 public class VirtualServerAddress {
 
   private final String _hostname;
   private final int _port;
-  private final int _virtualId;
+  private final int _workerId;
 
-  public VirtualServerAddress(String hostname, int port, int virtualId) {
+  public VirtualServerAddress(String hostname, int port, int workerId) {
     _hostname = hostname;
     _port = port;
-    _virtualId = virtualId;
+    _workerId = workerId;
   }
 
-  public VirtualServerAddress(VirtualServer server) {
-    this(server.getHostname(), server.getQueryMailboxPort(), server.getVirtualId());
+  public VirtualServerAddress(QueryServerInstance server, int workerId) {
+    this(server.getHostname(), server.getQueryMailboxPort(), workerId);
   }
 
   /**
@@ -75,8 +69,8 @@ public class VirtualServerAddress {
     return _port;
   }
 
-  public int virtualId() {
-    return _virtualId;
+  public int workerId() {
+    return _workerId;
   }
 
   @Override
@@ -89,17 +83,17 @@ public class VirtualServerAddress {
     }
     VirtualServerAddress that = (VirtualServerAddress) o;
     return _port == that._port
-        && _virtualId == that._virtualId
+        && _workerId == that._workerId
         && Objects.equals(_hostname, that._hostname);
   }
 
   @Override
   public int hashCode() {
-    return Objects.hash(_hostname, _port, _virtualId);
+    return Objects.hash(_hostname, _port, _workerId);
   }
 
   @Override
   public String toString() {
-    return _virtualId + "@" + _hostname + ":" + _port;
+    return _workerId + "@" + _hostname + ":" + _port;
   }
 }
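
A quick sketch of the renamed VirtualServerAddress API with arbitrary values; the worker ID is now supplied explicitly instead of being carried by a VirtualServer:

    VirtualServerAddress address = new VirtualServerAddress("localhost", 8421, /*workerId=*/2);
    address.hostname();   // "localhost"
    address.port();       // 8421
    address.workerId();   // 2
    address.toString();   // "2@localhost:8421"
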
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerInstance.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerInstance.java
deleted file mode 100644
index f635b99885..0000000000
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerInstance.java
+++ /dev/null
@@ -1,56 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.pinot.query.routing;
-
-import java.util.Map;
-import org.apache.helix.model.InstanceConfig;
-import org.apache.helix.zookeeper.datamodel.ZNRecord;
-import org.apache.pinot.core.transport.ServerInstance;
-import org.apache.pinot.spi.utils.CommonConstants;
-
-
-/**
- * WorkerInstance is a wrapper around {@link ServerInstance}.
- *
- * <p>This can be considered as a simplified version which directly enable host-port initialization.
- */
-public class WorkerInstance extends ServerInstance {
-
-  public WorkerInstance(InstanceConfig instanceConfig) {
-    super(instanceConfig);
-  }
-
-  public WorkerInstance(String hostname, int nettyPort, int grpcPort, int servicePort, int mailboxPort) {
-    super(toInstanceConfig(hostname, nettyPort, grpcPort, servicePort, mailboxPort));
-  }
-
-  private static InstanceConfig toInstanceConfig(String hostname, int nettyPort, int grpcPort, int servicePort,
-      int mailboxPort) {
-    String server = String.format("%s%s_%d", CommonConstants.Helix.PREFIX_OF_SERVER_INSTANCE, hostname, nettyPort);
-    InstanceConfig instanceConfig = InstanceConfig.toInstanceConfig(server);
-    ZNRecord znRecord = instanceConfig.getRecord();
-    Map<String, String> simpleFields = znRecord.getSimpleFields();
-    simpleFields.put(CommonConstants.Helix.Instance.GRPC_PORT_KEY, String.valueOf(grpcPort));
-    simpleFields.put(CommonConstants.Helix.Instance.MULTI_STAGE_QUERY_ENGINE_SERVICE_PORT_KEY,
-        String.valueOf(servicePort));
-    simpleFields.put(CommonConstants.Helix.Instance.MULTI_STAGE_QUERY_ENGINE_MAILBOX_PORT_KEY,
-        String.valueOf(mailboxPort));
-    return instanceConfig;
-  }
-}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
index dbebcd7692..c1193a9c5b 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
@@ -19,8 +19,8 @@
 package org.apache.pinot.query.routing;
 
 import com.google.common.base.Preconditions;
-import com.google.common.collect.Lists;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -33,7 +33,7 @@ import org.apache.pinot.core.routing.RoutingTable;
 import org.apache.pinot.core.routing.TimeBoundaryInfo;
 import org.apache.pinot.core.transport.ServerInstance;
 import org.apache.pinot.query.planner.PlannerUtils;
-import org.apache.pinot.query.planner.StageMetadata;
+import org.apache.pinot.query.planner.physical.DispatchablePlanMetadata;
 import org.apache.pinot.spi.config.table.TableType;
 import org.apache.pinot.spi.utils.CommonConstants;
 import org.apache.pinot.spi.utils.builder.TableNameBuilder;
@@ -62,102 +62,118 @@ public class WorkerManager {
     _routingManager = routingManager;
   }
 
-  public void assignWorkerToStage(int stageId, StageMetadata stageMetadata, long requestId,
+  public void assignWorkerToStage(int stageId, DispatchablePlanMetadata dispatchablePlanMetadata, long requestId,
       Map<String, String> options, Set<String> tableNames) {
-    if (isLeafStage(stageMetadata)) {
-      // --- LEAF STAGE ---
-      // table scan stage, need to attach server as well as segment info for each physical table type.
-      List<String> scannedTables = stageMetadata.getScannedTables();
-      String logicalTableName = scannedTables.get(0);
-      Map<String, RoutingTable> routingTableMap = getRoutingTable(logicalTableName, requestId);
-      if (routingTableMap.size() == 0) {
-        throw new IllegalArgumentException("Unable to find routing entries for table: " + logicalTableName);
-      }
-      // acquire time boundary info if it is a hybrid table.
-      if (routingTableMap.size() > 1) {
-        TimeBoundaryInfo timeBoundaryInfo = _routingManager.getTimeBoundaryInfo(TableNameBuilder
-            .forType(TableType.OFFLINE).tableNameWithType(TableNameBuilder.extractRawTableName(logicalTableName)));
-        if (timeBoundaryInfo != null) {
-          stageMetadata.setTimeBoundaryInfo(timeBoundaryInfo);
-        } else {
-          // remove offline table routing if no time boundary info is acquired.
-          routingTableMap.remove(TableType.OFFLINE.name());
-        }
-      }
-
-      // extract all the instances associated to each table type
-      Map<ServerInstance, Map<String, List<String>>> serverInstanceToSegmentsMap = new HashMap<>();
-      for (Map.Entry<String, RoutingTable> routingEntry : routingTableMap.entrySet()) {
-        String tableType = routingEntry.getKey();
-        RoutingTable routingTable = routingEntry.getValue();
-        // for each server instance, attach all table types and their associated segment list.
-        for (Map.Entry<ServerInstance, List<String>> serverEntry
-            : routingTable.getServerInstanceToSegmentsMap().entrySet()) {
-          serverInstanceToSegmentsMap.putIfAbsent(serverEntry.getKey(), new HashMap<>());
-          Map<String, List<String>> tableTypeToSegmentListMap = serverInstanceToSegmentsMap.get(serverEntry.getKey());
-          Preconditions.checkState(tableTypeToSegmentListMap.put(tableType, serverEntry.getValue()) == null,
-              "Entry for server {} and table type: {} already exist!", serverEntry.getKey(), tableType);
-        }
-      }
-      int globalIdx = 0;
-      List<VirtualServer> serverInstanceList = new ArrayList<>();
-      for (ServerInstance serverInstance : serverInstanceToSegmentsMap.keySet()) {
-        serverInstanceList.add(new VirtualServer(serverInstance, globalIdx++));
-      }
-      stageMetadata.setServerInstances(serverInstanceList);
-      stageMetadata.setServerInstanceToSegmentsMap(serverInstanceToSegmentsMap);
-    } else if (PlannerUtils.isRootStage(stageId)) {
+    if (PlannerUtils.isRootStage(stageId)) {
       // --- ROOT STAGE / BROKER REDUCE STAGE ---
       // ROOT stage doesn't have a QueryServer as it is strictly only reducing results.
       // here we simply assign the worker instance with identical server/mailbox port number.
-      stageMetadata.setServerInstances(Lists.newArrayList(
-          new VirtualServer(new WorkerInstance(_hostName, _port, _port, _port, _port), 0)));
+      dispatchablePlanMetadata.setServerInstanceToWorkerIdMap(Collections.singletonMap(
+          new QueryServerInstance(_hostName, _port, _port), Collections.singletonList(0)));
+      dispatchablePlanMetadata.setTotalWorkerCount(1);
+    } else if (isLeafStage(dispatchablePlanMetadata)) {
+      // --- LEAF STAGE ---
+      assignWorkerToLeafStage(requestId, dispatchablePlanMetadata);
     } else {
       // --- INTERMEDIATE STAGES ---
       // If the query has more than one table, it is possible that the tables could be hosted on different tenants.
       // The intermediate stage will be processed on servers randomly picked from the tenants belonging to either or
       // all of the tables in the query.
       // TODO: actually make assignment strategy decisions for intermediate stages
-      Set<ServerInstance> serverInstances = new HashSet<>();
-      if (tableNames.size() == 0) {
-        // This could be the case from queries that don't actually fetch values from the tables. In such cases the
-        // routing need not be tenant aware.
-        // Eg: SELECT 1 AS one FROM select_having_expression_test_test_having HAVING 1 > 2;
-        serverInstances = _routingManager.getEnabledServerInstanceMap().values().stream().collect(Collectors.toSet());
+      assignWorkerToIntermediateStage(dispatchablePlanMetadata, tableNames, options);
+    }
+  }
+
+  private void assignWorkerToLeafStage(long requestId, DispatchablePlanMetadata dispatchablePlanMetadata) {
+    // table scan stage, need to attach server as well as segment info for each physical table type.
+    List<String> scannedTables = dispatchablePlanMetadata.getScannedTables();
+    String logicalTableName = scannedTables.get(0);
+    Map<String, RoutingTable> routingTableMap = getRoutingTable(logicalTableName, requestId);
+    if (routingTableMap.size() == 0) {
+      throw new IllegalArgumentException("Unable to find routing entries for table: " + logicalTableName);
+    }
+    // acquire time boundary info if it is a hybrid table.
+    if (routingTableMap.size() > 1) {
+      TimeBoundaryInfo timeBoundaryInfo = _routingManager.getTimeBoundaryInfo(TableNameBuilder
+          .forType(TableType.OFFLINE).tableNameWithType(TableNameBuilder.extractRawTableName(logicalTableName)));
+      if (timeBoundaryInfo != null) {
+        dispatchablePlanMetadata.setTimeBoundaryInfo(timeBoundaryInfo);
       } else {
-        serverInstances = fetchServersForIntermediateStage(tableNames);
+        // remove offline table routing if no time boundary info is acquired.
+        routingTableMap.remove(TableType.OFFLINE.name());
       }
+    }
 
-      stageMetadata.setServerInstances(
-          assignServers(serverInstances, stageMetadata.isRequiresSingletonInstance(), options));
+    // extract all the instances associated to each table type
+    Map<ServerInstance, Map<String, List<String>>> serverInstanceToSegmentsMap = new HashMap<>();
+    for (Map.Entry<String, RoutingTable> routingEntry : routingTableMap.entrySet()) {
+      String tableType = routingEntry.getKey();
+      RoutingTable routingTable = routingEntry.getValue();
+      // for each server instance, attach all table types and their associated segment list.
+      for (Map.Entry<ServerInstance, List<String>> serverEntry
+          : routingTable.getServerInstanceToSegmentsMap().entrySet()) {
+        serverInstanceToSegmentsMap.putIfAbsent(serverEntry.getKey(), new HashMap<>());
+        Map<String, List<String>> tableTypeToSegmentListMap = serverInstanceToSegmentsMap.get(serverEntry.getKey());
+        Preconditions.checkState(tableTypeToSegmentListMap.put(tableType, serverEntry.getValue()) == null,
+            "Entry for server {} and table type: {} already exist!", serverEntry.getKey(), tableType);
+      }
+    }
+    int globalIdx = 0;
+    Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap = new HashMap<>();
+    Map<Integer, Map<String, List<String>>> workerIdToSegmentsMap = new HashMap<>();
+    for (Map.Entry<ServerInstance, Map<String, List<String>>> entry : serverInstanceToSegmentsMap.entrySet()) {
+      QueryServerInstance queryServerInstance = new QueryServerInstance(entry.getKey());
+      serverInstanceToWorkerIdMap.put(queryServerInstance, Collections.singletonList(globalIdx));
+      workerIdToSegmentsMap.put(globalIdx, entry.getValue());
+      globalIdx++;
     }
+    dispatchablePlanMetadata.setServerInstanceToWorkerIdMap(serverInstanceToWorkerIdMap);
+    dispatchablePlanMetadata.setWorkerIdToSegmentsMap(workerIdToSegmentsMap);
+    dispatchablePlanMetadata.setTotalWorkerCount(globalIdx);
   }
 
-  private static List<VirtualServer> assignServers(Set<ServerInstance> servers,
+  private void assignWorkerToIntermediateStage(DispatchablePlanMetadata dispatchablePlanMetadata,
+      Set<String> tableNames, Map<String, String> options) {
+    // If the query has more than one table, it is possible that the tables could be hosted on different tenants.
+    // The intermediate stage will be processed on servers randomly picked from the tenants belonging to either or
+    // all of the tables in the query.
+    // TODO: actually make assignment strategy decisions for intermediate stages
+    Set<ServerInstance> serverInstances = new HashSet<>();
+    if (tableNames.size() == 0) {
+      // This could be the case from queries that don't actually fetch values from the tables. In such cases the
+      // routing need not be tenant aware.
+      // Eg: SELECT 1 AS one FROM select_having_expression_test_test_having HAVING 1 > 2;
+      serverInstances = _routingManager.getEnabledServerInstanceMap().values().stream().collect(Collectors.toSet());
+    } else {
+      serverInstances = fetchServersForIntermediateStage(tableNames);
+    }
+    assignServers(dispatchablePlanMetadata, serverInstances, dispatchablePlanMetadata.isRequiresSingletonInstance(),
+        options);
+  }
+
+  private static void assignServers(DispatchablePlanMetadata dispatchablePlanMetadata, Set<ServerInstance> servers,
       boolean requiresSingletonInstance, Map<String, String> options) {
     int stageParallelism = Integer.parseInt(
         options.getOrDefault(CommonConstants.Broker.Request.QueryOptionKey.STAGE_PARALLELISM, "1"));
     List<ServerInstance> serverInstances = new ArrayList<>(servers);
-    List<VirtualServer> virtualServerList = new ArrayList<>();
+    Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap = new HashMap<>();
     if (requiresSingletonInstance) {
       // require singleton should return a single global worker ID with 0;
       ServerInstance serverInstance = serverInstances.get(RANDOM.nextInt(serverInstances.size()));
-      virtualServerList.add(new VirtualServer(serverInstance, 0));
+      serverInstanceToWorkerIdMap.put(new QueryServerInstance(serverInstance), Collections.singletonList(0));
+      dispatchablePlanMetadata.setTotalWorkerCount(1);
     } else {
       int globalIdx = 0;
       for (ServerInstance server : servers) {
-        String hostname = server.getHostname();
-        if (server.getQueryServicePort() > 0 && server.getQueryMailboxPort() > 0
-            && !hostname.startsWith(CommonConstants.Helix.PREFIX_OF_BROKER_INSTANCE)
-            && !hostname.startsWith(CommonConstants.Helix.PREFIX_OF_CONTROLLER_INSTANCE)
-            && !hostname.startsWith(CommonConstants.Helix.PREFIX_OF_MINION_INSTANCE)) {
-          for (int virtualId = 0; virtualId < stageParallelism; virtualId++) {
-              virtualServerList.add(new VirtualServer(server, globalIdx++));
-            }
-          }
+        List<Integer> workerIdList = new ArrayList<>();
+        for (int virtualId = 0; virtualId < stageParallelism; virtualId++) {
+          workerIdList.add(globalIdx++);
         }
+        serverInstanceToWorkerIdMap.put(new QueryServerInstance(server), workerIdList);
+      }
+      dispatchablePlanMetadata.setTotalWorkerCount(globalIdx);
     }
-    return virtualServerList;
+    dispatchablePlanMetadata.setServerInstanceToWorkerIdMap(serverInstanceToWorkerIdMap);
   }
 
   /**
@@ -198,8 +214,8 @@ public class WorkerManager {
 
   // TODO: Find a better way to determine whether a stage is leaf stage or intermediary. We could have query plans that
   //       process table data even in intermediary stages.
-  private boolean isLeafStage(StageMetadata stageMetadata) {
-    return stageMetadata.getScannedTables().size() == 1;
+  private boolean isLeafStage(DispatchablePlanMetadata dispatchablePlanMetadata) {
+    return dispatchablePlanMetadata.getScannedTables().size() == 1;
   }
 
   private Set<ServerInstance> fetchServersForIntermediateStage(Set<String> tableNames) {
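
A simplified sketch of the intermediate-stage assignment performed by assignServers above, assuming two hypothetical servers and STAGE_PARALLELISM=2:

    List<QueryServerInstance> servers = Arrays.asList(
        new QueryServerInstance("host1", 8421, 8421),
        new QueryServerInstance("host2", 8421, 8421));
    int stageParallelism = 2;
    Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap = new HashMap<>();
    int globalIdx = 0;
    for (QueryServerInstance server : servers) {
      List<Integer> workerIds = new ArrayList<>();
      for (int i = 0; i < stageParallelism; i++) {
        workerIds.add(globalIdx++);  // worker IDs are globally unique within the stage
      }
      serverInstanceToWorkerIdMap.put(server, workerIds);
    }
    // serverInstanceToWorkerIdMap: {host1=[0, 1], host2=[2, 3]}; total worker count = globalIdx = 4
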
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerMetadata.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerMetadata.java
new file mode 100644
index 0000000000..b99adeee29
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerMetadata.java
@@ -0,0 +1,106 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.routing;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.core.type.TypeReference;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.spi.utils.JsonUtils;
+
+
+/**
+ * {@code WorkerMetadata} is used to send worker-level info about how to execute a stage on a particular worker.
+ *
+ * <p>It contains information specific to a single worker within a stage, such as:</p>
+ * <ul>
+ *   <li>the underlying segments this particular worker needs to execute.</li>
+ *   <li>the mailbox info required to construct data transfer linkages.</li>
+ *   <li>the partitioning mechanism of the data being executed on this worker.</li>
+ * </ul>
+ *
+ * TODO: WorkerMetadata currently doesn't carry info directly about how to construct the mailboxes. Instead it relies
+ * on MailboxSendNode and MailboxReceiveNode to derive the info at runtime. This should be moved to plan time soon.
+ */
+public class WorkerMetadata {
+  private final VirtualServerAddress _virtualServerAddress;
+  private final Map<String, String> _customProperties;
+
+  private WorkerMetadata(VirtualServerAddress virtualServerAddress, Map<String, String> customProperties) {
+    _virtualServerAddress = virtualServerAddress;
+    _customProperties = customProperties;
+  }
+
+  public VirtualServerAddress getVirtualServerAddress() {
+    return _virtualServerAddress;
+  }
+
+  public Map<String, String> getCustomProperties() {
+    return _customProperties;
+  }
+
+  public static class Builder {
+    public static final String TABLE_SEGMENTS_MAP_KEY = "tableSegmentsMap";
+    private VirtualServerAddress _virtualServerAddress;
+    private Map<String, String> _customProperties;
+
+    public Builder() {
+      _customProperties = new HashMap<>();
+    }
+
+    public Builder setVirtualServerAddress(VirtualServerAddress virtualServerAddress) {
+      _virtualServerAddress = virtualServerAddress;
+      return this;
+    }
+
+    public Builder addTableSegmentsMap(Map<String, List<String>> tableSegmentsMap) {
+      try {
+        String tableSegmentsMapStr = JsonUtils.objectToString(tableSegmentsMap);
+        _customProperties.put(TABLE_SEGMENTS_MAP_KEY, tableSegmentsMapStr);
+      } catch (JsonProcessingException e) {
+        throw new RuntimeException("Unable to serialize table segments map", e);
+      }
+      return this;
+    }
+
+    public WorkerMetadata build() {
+      return new WorkerMetadata(_virtualServerAddress, _customProperties);
+    }
+
+    public void putAllCustomProperties(Map<String, String> customPropertyMap) {
+      _customProperties.putAll(customPropertyMap);
+    }
+  }
+
+  public static Map<String, List<String>> getTableSegmentsMap(WorkerMetadata workerMetadata) {
+    String tableSegmentKeyStr = workerMetadata.getCustomProperties().get(Builder.TABLE_SEGMENTS_MAP_KEY);
+    if (tableSegmentKeyStr != null) {
+      try {
+        return JsonUtils.stringToObject(tableSegmentKeyStr, new TypeReference<Map<String, List<String>>>() {
+        });
+      } catch (IOException e) {
+        throw new RuntimeException("Unable to deserialize table segments map", e);
+      }
+    } else {
+      return null;
+    }
+  }
+}
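
A minimal sketch of the leaf-stage round trip enabled by WorkerMetadata above: the table-type to segments map is serialized into a custom property at planning time and read back on the server (segment names are made up):

    Map<String, List<String>> tableSegmentsMap = new HashMap<>();
    tableSegmentsMap.put("OFFLINE", Arrays.asList("myTable_OFFLINE_0", "myTable_OFFLINE_1"));
    WorkerMetadata workerMetadata = new WorkerMetadata.Builder()
        .setVirtualServerAddress(new VirtualServerAddress("localhost", 8421, /*workerId=*/0))
        .addTableSegmentsMap(tableSegmentsMap)
        .build();
    // Server side (e.g. leaf-stage request construction) reads it back:
    Map<String, List<String>> restored = WorkerMetadata.getTableSegmentsMap(workerMetadata);
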
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
index 4177e0b0b2..6d6c4937a0 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
@@ -28,9 +28,10 @@ import java.util.concurrent.locks.Lock;
 import java.util.concurrent.locks.ReentrantLock;
 import java.util.stream.Collectors;
 import org.apache.calcite.rel.RelDistribution;
+import org.apache.pinot.query.planner.ExplainPlanStageVisitor;
 import org.apache.pinot.query.planner.PlannerUtils;
 import org.apache.pinot.query.planner.QueryPlan;
-import org.apache.pinot.query.planner.StageMetadata;
+import org.apache.pinot.query.planner.physical.DispatchablePlanMetadata;
 import org.apache.pinot.query.planner.stage.AbstractStageNode;
 import org.apache.pinot.query.planner.stage.AggregateNode;
 import org.apache.pinot.query.planner.stage.FilterNode;
@@ -38,7 +39,6 @@ import org.apache.pinot.query.planner.stage.JoinNode;
 import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
 import org.apache.pinot.query.planner.stage.ProjectNode;
 import org.apache.pinot.query.planner.stage.StageNode;
-import org.apache.pinot.query.routing.VirtualServer;
 import org.testng.Assert;
 import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
@@ -79,7 +79,7 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
   }
 
   private static void assertGroupBySingletonAfterJoin(QueryPlan queryPlan, boolean shouldRewrite) throws Exception {
-    for (Map.Entry<Integer, StageMetadata> e : queryPlan.getStageMetadataMap().entrySet()) {
+    for (Map.Entry<Integer, DispatchablePlanMetadata> e : queryPlan.getDispatchablePlanMetadataMap().entrySet()) {
       if (e.getValue().getScannedTables().size() == 0 && !PlannerUtils.isRootStage(e.getKey())) {
         StageNode node = queryPlan.getQueryStageMap().get(e.getKey());
         while (node != null) {
@@ -113,25 +113,31 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
     String query = "SELECT * FROM a JOIN b ON a.col1 = b.col2";
     QueryPlan queryPlan = _queryEnvironment.planQuery(query);
     Assert.assertEquals(queryPlan.getQueryStageMap().size(), 4);
-    Assert.assertEquals(queryPlan.getStageMetadataMap().size(), 4);
-    for (Map.Entry<Integer, StageMetadata> e : queryPlan.getStageMetadataMap().entrySet()) {
+    Assert.assertEquals(queryPlan.getDispatchablePlanMetadataMap().size(), 4);
+    for (Map.Entry<Integer, DispatchablePlanMetadata> e : queryPlan.getDispatchablePlanMetadataMap().entrySet()) {
       List<String> tables = e.getValue().getScannedTables();
       if (tables.size() == 1) {
         // table scan stages; for tableA it should have 2 hosts, for tableB it should have only 1
         Assert.assertEquals(
-            e.getValue().getServerInstances().stream().map(VirtualServer::toString).collect(Collectors.toList()),
-            tables.get(0).equals("a") ? ImmutableList.of("0@Server_localhost_2", "1@Server_localhost_1")
-                : ImmutableList.of("0@Server_localhost_1"));
+            e.getValue().getServerInstanceToWorkerIdMap().entrySet().stream()
+                .map(ExplainPlanStageVisitor::stringifyQueryServerInstanceToWorkerIdsEntry)
+                .collect(Collectors.toSet()),
+            tables.get(0).equals("a") ? ImmutableList.of("localhost@{1,1}|[1]", "localhost@{2,2}|[0]")
+                : ImmutableList.of("localhost@{1,1}|[0]"));
       } else if (!PlannerUtils.isRootStage(e.getKey())) {
         // join stage should have both servers used.
         Assert.assertEquals(
-            e.getValue().getServerInstances().stream().map(VirtualServer::toString).collect(Collectors.toSet()),
-            ImmutableSet.of("1@Server_localhost_1", "0@Server_localhost_2"));
+            e.getValue().getServerInstanceToWorkerIdMap().entrySet().stream()
+                .map(ExplainPlanStageVisitor::stringifyQueryServerInstanceToWorkerIdsEntry)
+                .collect(Collectors.toSet()),
+            ImmutableSet.of("localhost@{1,1}|[1]", "localhost@{2,2}|[0]"));
       } else {
         // reduce stage should have the reducer instance.
         Assert.assertEquals(
-            e.getValue().getServerInstances().stream().map(VirtualServer::toString).collect(Collectors.toSet()),
-            ImmutableSet.of("0@Server_localhost_3"));
+            e.getValue().getServerInstanceToWorkerIdMap().entrySet().stream()
+                .map(ExplainPlanStageVisitor::stringifyQueryServerInstanceToWorkerIdsEntry)
+                .collect(Collectors.toSet()),
+            ImmutableSet.of("localhost@{3,3}|[0]"));
       }
     }
   }
@@ -142,7 +148,8 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
         + "WHERE a.col3 >= 0 AND a.col2 IN ('b') AND b.col3 < 0";
     QueryPlan queryPlan = _queryEnvironment.planQuery(query);
     List<StageNode> intermediateStageRoots =
-        queryPlan.getStageMetadataMap().entrySet().stream().filter(e -> e.getValue().getScannedTables().size() == 0)
+        queryPlan.getDispatchablePlanMetadataMap().entrySet().stream()
+            .filter(e -> e.getValue().getScannedTables().size() == 0)
             .map(e -> queryPlan.getQueryStageMap().get(e.getKey())).collect(Collectors.toList());
     // Assert that no project of filter node for any intermediate stage because all should've been pushed down.
     for (StageNode roots : intermediateStageRoots) {
@@ -154,25 +161,24 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
   public void testQueryRoutingManagerCompilation() {
     String query = "SELECT * FROM d_OFFLINE";
     QueryPlan queryPlan = _queryEnvironment.planQuery(query);
-    List<StageMetadata> tableScanMetadataList = queryPlan.getStageMetadataMap().values().stream()
+    List<DispatchablePlanMetadata> tableScanMetadataList = queryPlan.getDispatchablePlanMetadataMap().values().stream()
         .filter(stageMetadata -> stageMetadata.getScannedTables().size() != 0).collect(Collectors.toList());
     Assert.assertEquals(tableScanMetadataList.size(), 1);
-    Assert.assertEquals(tableScanMetadataList.get(0).getServerInstances().size(), 2);
+    Assert.assertEquals(tableScanMetadataList.get(0).getServerInstanceToWorkerIdMap().size(), 2);
 
     query = "SELECT * FROM d_REALTIME";
     queryPlan = _queryEnvironment.planQuery(query);
-    tableScanMetadataList = queryPlan.getStageMetadataMap().values().stream()
+    tableScanMetadataList = queryPlan.getDispatchablePlanMetadataMap().values().stream()
         .filter(stageMetadata -> stageMetadata.getScannedTables().size() != 0).collect(Collectors.toList());
     Assert.assertEquals(tableScanMetadataList.size(), 1);
-    Assert.assertEquals(tableScanMetadataList.get(0).getServerInstances().size(), 1);
-    Assert.assertEquals(tableScanMetadataList.get(0).getServerInstances().get(0).toString(), "0@Server_localhost_2");
+    Assert.assertEquals(tableScanMetadataList.get(0).getServerInstanceToWorkerIdMap().size(), 1);
 
     query = "SELECT * FROM d";
     queryPlan = _queryEnvironment.planQuery(query);
-    tableScanMetadataList = queryPlan.getStageMetadataMap().values().stream()
+    tableScanMetadataList = queryPlan.getDispatchablePlanMetadataMap().values().stream()
         .filter(stageMetadata -> stageMetadata.getScannedTables().size() != 0).collect(Collectors.toList());
     Assert.assertEquals(tableScanMetadataList.size(), 1);
-    Assert.assertEquals(tableScanMetadataList.get(0).getServerInstances().size(), 2);
+    Assert.assertEquals(tableScanMetadataList.get(0).getServerInstanceToWorkerIdMap().size(), 2);
   }
 
   // Test that plan query can be run as multi-thread.
@@ -233,19 +239,21 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
     String query = "SELECT /*+ aggFinalStage */ col1, COUNT(*) FROM b GROUP BY col1";
     QueryPlan queryPlan = _queryEnvironment.planQuery(query);
     Assert.assertEquals(queryPlan.getQueryStageMap().size(), 2);
-    Assert.assertEquals(queryPlan.getStageMetadataMap().size(), 2);
-    for (Map.Entry<Integer, StageMetadata> e : queryPlan.getStageMetadataMap().entrySet()) {
+    Assert.assertEquals(queryPlan.getDispatchablePlanMetadataMap().size(), 2);
+    for (Map.Entry<Integer, DispatchablePlanMetadata> e : queryPlan.getDispatchablePlanMetadataMap().entrySet()) {
       List<String> tables = e.getValue().getScannedTables();
       if (tables.size() != 0) {
         // table scan stages; for tableB it should have only 1
-        Assert.assertEquals(e.getValue().getServerInstances().stream()
-                .map(VirtualServer::toString).sorted().collect(Collectors.toList()),
-            ImmutableList.of("0@Server_localhost_1"));
+        Assert.assertEquals(e.getValue().getServerInstanceToWorkerIdMap().entrySet().stream()
+                .map(ExplainPlanStageVisitor::stringifyQueryServerInstanceToWorkerIdsEntry)
+                .collect(Collectors.toSet()),
+            ImmutableList.of("localhost@{1,1}|[0]"));
       } else if (!PlannerUtils.isRootStage(e.getKey())) {
         // join stage should have both servers used.
-        Assert.assertEquals(e.getValue().getServerInstances().stream()
-                .map(VirtualServer::toString).sorted().collect(Collectors.toList()),
-            ImmutableList.of("0@Server_localhost_1", "0@Server_localhost_2"));
+        Assert.assertEquals(e.getValue().getServerInstanceToWorkerIdMap().entrySet().stream()
+                .map(ExplainPlanStageVisitor::stringifyQueryServerInstanceToWorkerIdsEntry)
+                .collect(Collectors.toSet()),
+            ImmutableList.of("localhost@{1,1}|[1]", "localhost@{2,2}|[0]"));
       }
     }
   }
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/testutils/MockRoutingManagerFactory.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/testutils/MockRoutingManagerFactory.java
index 46dd4aff8f..18850dbccc 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/testutils/MockRoutingManagerFactory.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/testutils/MockRoutingManagerFactory.java
@@ -24,15 +24,17 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import org.apache.helix.model.InstanceConfig;
+import org.apache.helix.zookeeper.datamodel.ZNRecord;
 import org.apache.pinot.common.config.provider.TableCache;
 import org.apache.pinot.common.request.BrokerRequest;
 import org.apache.pinot.core.routing.RoutingManager;
 import org.apache.pinot.core.routing.RoutingTable;
 import org.apache.pinot.core.routing.TimeBoundaryInfo;
 import org.apache.pinot.core.transport.ServerInstance;
-import org.apache.pinot.query.routing.WorkerInstance;
 import org.apache.pinot.spi.config.table.TableType;
 import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.utils.CommonConstants;
 import org.apache.pinot.spi.utils.builder.TableNameBuilder;
 
 import static org.mockito.ArgumentMatchers.anyString;
@@ -65,7 +67,7 @@ public class MockRoutingManagerFactory {
 
     _tableServerSegmentMap = new HashMap<>();
     for (int port : ports) {
-      _serverInstances.put(toHostname(port), new WorkerInstance(HOST_NAME, port, port, port, port));
+      _serverInstances.put(toHostname(port), getServerInstance(HOST_NAME, port, port, port, port));
     }
   }
 
@@ -119,6 +121,20 @@ public class MockRoutingManagerFactory {
     return String.format("%s_%d", HOST_NAME, port);
   }
 
+  private static ServerInstance getServerInstance(String hostname, int nettyPort, int grpcPort, int servicePort,
+      int mailboxPort) {
+    String server = String.format("%s%s_%d", CommonConstants.Helix.PREFIX_OF_SERVER_INSTANCE, hostname, nettyPort);
+    InstanceConfig instanceConfig = InstanceConfig.toInstanceConfig(server);
+    ZNRecord znRecord = instanceConfig.getRecord();
+    Map<String, String> simpleFields = znRecord.getSimpleFields();
+    simpleFields.put(CommonConstants.Helix.Instance.GRPC_PORT_KEY, String.valueOf(grpcPort));
+    simpleFields.put(CommonConstants.Helix.Instance.MULTI_STAGE_QUERY_ENGINE_SERVICE_PORT_KEY,
+        String.valueOf(servicePort));
+    simpleFields.put(CommonConstants.Helix.Instance.MULTI_STAGE_QUERY_ENGINE_MAILBOX_PORT_KEY,
+        String.valueOf(mailboxPort));
+    return new ServerInstance(instanceConfig);
+  }
+
   private void registerTableNameWithType(Schema schema, String tableNameWithType) {
     String rawTableName = TableNameBuilder.extractRawTableName(tableNameWithType);
     _tableNameMap.put(tableNameWithType, rawTableName);
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
index 5d6e764125..b30a1f9f93 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/QueryRunner.java
@@ -18,7 +18,6 @@
  */
 package org.apache.pinot.query.runtime;
 
-import com.google.common.base.Preconditions;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
@@ -40,11 +39,11 @@ import org.apache.pinot.core.query.executor.ServerQueryExecutorV1Impl;
 import org.apache.pinot.core.query.request.ServerQueryRequest;
 import org.apache.pinot.core.query.scheduler.resources.ResourceManager;
 import org.apache.pinot.query.mailbox.MailboxService;
-import org.apache.pinot.query.planner.StageMetadata;
 import org.apache.pinot.query.planner.stage.MailboxSendNode;
 import org.apache.pinot.query.planner.stage.StageNode;
-import org.apache.pinot.query.routing.VirtualServer;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
+import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
 import org.apache.pinot.query.runtime.executor.OpChainSchedulerService;
 import org.apache.pinot.query.runtime.executor.RoundRobinScheduler;
@@ -169,8 +168,7 @@ public class QueryRunner {
       StageNode stageRoot = distributedStagePlan.getStageRoot();
       OpChain rootOperator = PhysicalPlanVisitor.build(stageRoot,
           new PlanRequestContext(_mailboxService, requestId, stageRoot.getStageId(), timeoutMs, deadlineMs,
-              new VirtualServerAddress(distributedStagePlan.getServer()), distributedStagePlan.getMetadataMap(),
-              isTraceEnabled));
+              distributedStagePlan.getServer(), distributedStagePlan.getStageMetadataList(), isTraceEnabled));
       _scheduler.register(rootOperator);
     }
   }
@@ -214,8 +212,8 @@ public class QueryRunner {
       MailboxSendNode sendNode = (MailboxSendNode) distributedStagePlan.getStageRoot();
       OpChainExecutionContext opChainExecutionContext =
           new OpChainExecutionContext(_mailboxService, requestId, sendNode.getStageId(),
-              new VirtualServerAddress(distributedStagePlan.getServer()), timeoutMs, deadlineMs,
-              distributedStagePlan.getMetadataMap(), isTraceEnabled);
+              distributedStagePlan.getServer(), timeoutMs, deadlineMs, distributedStagePlan.getStageMetadataList(),
+              isTraceEnabled);
       MultiStageOperator leafStageOperator =
           new LeafStageTransferableBlockOperator(opChainExecutionContext, serverQueryResults, sendNode.getDataSchema());
       mailboxSendOperator =
@@ -238,12 +236,10 @@ public class QueryRunner {
   private static List<ServerPlanRequestContext> constructServerQueryRequests(DistributedStagePlan distributedStagePlan,
       Map<String, String> requestMetadataMap, ZkHelixPropertyStore<ZNRecord> helixPropertyStore,
       MailboxService mailboxService, long deadlineMs) {
-    StageMetadata stageMetadata = distributedStagePlan.getMetadataMap().get(distributedStagePlan.getStageId());
-    Preconditions.checkState(stageMetadata.getScannedTables().size() == 1,
-        "Server request for V2 engine should only have 1 scan table per request.");
-    String rawTableName = stageMetadata.getScannedTables().get(0);
-    Map<String, List<String>> tableToSegmentListMap =
-        stageMetadata.getServerInstanceToSegmentsMap().get(distributedStagePlan.getServer().getServer());
+    StageMetadata stageMetadata = distributedStagePlan.getCurrentStageMetadata();
+    WorkerMetadata workerMetadata = distributedStagePlan.getCurrentWorkerMetadata();
+    String rawTableName = StageMetadata.getTableName(stageMetadata);
+    Map<String, List<String>> tableToSegmentListMap = WorkerMetadata.getTableSegmentsMap(workerMetadata);
     List<ServerPlanRequestContext> requests = new ArrayList<>();
     for (Map.Entry<String, List<String>> tableEntry : tableToSegmentListMap.entrySet()) {
       String tableType = tableEntry.getKey();
@@ -255,17 +251,17 @@ public class QueryRunner {
             TableNameBuilder.forType(TableType.OFFLINE).tableNameWithType(rawTableName));
         Schema schema = ZKMetadataProvider.getTableSchema(helixPropertyStore,
             TableNameBuilder.forType(TableType.OFFLINE).tableNameWithType(rawTableName));
-        requests.add(
-            ServerRequestPlanVisitor.build(mailboxService, distributedStagePlan, requestMetadataMap, tableConfig,
-                schema, stageMetadata.getTimeBoundaryInfo(), TableType.OFFLINE, tableEntry.getValue(), deadlineMs));
+        requests.add(ServerRequestPlanVisitor.build(mailboxService, distributedStagePlan, requestMetadataMap,
+            tableConfig, schema, StageMetadata.getTimeBoundary(stageMetadata), TableType.OFFLINE,
+            tableEntry.getValue(), deadlineMs));
       } else if (TableType.REALTIME.name().equals(tableType)) {
         TableConfig tableConfig = ZKMetadataProvider.getTableConfig(helixPropertyStore,
             TableNameBuilder.forType(TableType.REALTIME).tableNameWithType(rawTableName));
         Schema schema = ZKMetadataProvider.getTableSchema(helixPropertyStore,
             TableNameBuilder.forType(TableType.REALTIME).tableNameWithType(rawTableName));
-        requests.add(
-            ServerRequestPlanVisitor.build(mailboxService, distributedStagePlan, requestMetadataMap, tableConfig,
-                schema, stageMetadata.getTimeBoundaryInfo(), TableType.REALTIME, tableEntry.getValue(), deadlineMs));
+        requests.add(ServerRequestPlanVisitor.build(mailboxService, distributedStagePlan, requestMetadataMap,
+            tableConfig, schema, StageMetadata.getTimeBoundary(stageMetadata), TableType.REALTIME,
+            tableEntry.getValue(), deadlineMs));
       } else {
         throw new IllegalArgumentException("Unsupported table type key: " + tableType);
       }
@@ -286,10 +282,8 @@ public class QueryRunner {
   }
 
   private boolean isLeafStage(DistributedStagePlan distributedStagePlan) {
-    int stageId = distributedStagePlan.getStageId();
-    VirtualServer serverInstance = distributedStagePlan.getServer();
-    StageMetadata stageMetadata = distributedStagePlan.getMetadataMap().get(stageId);
-    Map<String, List<String>> segments = stageMetadata.getServerInstanceToSegmentsMap().get(serverInstance.getServer());
+    WorkerMetadata workerMetadata = distributedStagePlan.getCurrentWorkerMetadata();
+    Map<String, List<String>> segments = WorkerMetadata.getTableSegmentsMap(workerMetadata);
     return segments != null && segments.size() > 0;
   }
 }
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/BaseMailboxReceiveOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/BaseMailboxReceiveOperator.java
index 33c346cff3..6934754b3d 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/BaseMailboxReceiveOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/BaseMailboxReceiveOperator.java
@@ -29,8 +29,8 @@ import org.apache.calcite.rel.RelDistribution;
 import org.apache.pinot.query.mailbox.MailboxIdUtils;
 import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.mailbox.ReceivingMailbox;
-import org.apache.pinot.query.routing.VirtualServer;
 import org.apache.pinot.query.routing.VirtualServerAddress;
+import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
 
 
@@ -59,26 +59,27 @@ public abstract class BaseMailboxReceiveOperator extends MultiStageOperator {
 
     long requestId = context.getRequestId();
     int receiverStageId = context.getStageId();
-    List<VirtualServer> senders = context.getMetadataMap().get(senderStageId).getServerInstances();
+    List<WorkerMetadata> senderMetadataList = context.getStageMetadataList().get(senderStageId).getWorkerMetadataList();
     VirtualServerAddress receiver = context.getServer();
     if (exchangeType == RelDistribution.Type.SINGLETON) {
-      VirtualServer singletonInstance = null;
-      for (VirtualServer sender : senders) {
-        if (sender.getHostname().equals(_mailboxService.getHostname())
-            && sender.getQueryMailboxPort() == _mailboxService.getPort()) {
-          Preconditions.checkState(singletonInstance == null, "Multiple instances found for SINGLETON exchange type");
-          singletonInstance = sender;
+      VirtualServerAddress singletonSender = null;
+      for (WorkerMetadata senderMetadata : senderMetadataList) {
+        VirtualServerAddress sender = senderMetadata.getVirtualServerAddress();
+        if (sender.hostname().equals(_mailboxService.getHostname()) && sender.port() == _mailboxService.getPort()) {
+          Preconditions.checkState(singletonSender == null, "Multiple instances found for SINGLETON exchange type");
+          singletonSender = sender;
         }
       }
-      Preconditions.checkState(singletonInstance != null, "Failed to find instance for SINGLETON exchange type");
+      Preconditions.checkState(singletonSender != null, "Failed to find instance for SINGLETON exchange type");
       _mailboxIds = Collections.singletonList(
-          MailboxIdUtils.toMailboxId(requestId, senderStageId, singletonInstance.getVirtualId(), receiverStageId,
-              receiver.virtualId()));
+          MailboxIdUtils.toMailboxId(requestId, senderStageId, singletonSender.workerId(), receiverStageId,
+              receiver.workerId()));
     } else {
-      _mailboxIds = new ArrayList<>(senders.size());
-      for (VirtualServer sender : senders) {
-        _mailboxIds.add(MailboxIdUtils.toMailboxId(requestId, senderStageId, sender.getVirtualId(), receiverStageId,
-            receiver.virtualId()));
+      _mailboxIds = new ArrayList<>(senderMetadataList.size());
+      for (WorkerMetadata senderMetadata : senderMetadataList) {
+        VirtualServerAddress sender = senderMetadata.getVirtualServerAddress();
+        _mailboxIds.add(MailboxIdUtils.toMailboxId(requestId, senderStageId, sender.workerId(), receiverStageId,
+            receiver.workerId()));
       }
     }
     _mailboxes = _mailboxIds.stream().map(_mailboxService::getReceivingMailbox)
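
A sketch of the mailbox-id derivation after this change, with the sender worker IDs now taken from WorkerMetadata; the request/stage IDs and the single sender below are arbitrary placeholders:

    long requestId = 123L;
    int senderStageId = 2;
    int receiverStageId = 1;
    int receiverWorkerId = 0;
    List<WorkerMetadata> senderMetadataList = Collections.singletonList(
        new WorkerMetadata.Builder()
            .setVirtualServerAddress(new VirtualServerAddress("localhost", 8421, /*workerId=*/0))
            .build());
    List<String> mailboxIds = new ArrayList<>(senderMetadataList.size());
    for (WorkerMetadata sender : senderMetadataList) {
      VirtualServerAddress senderAddress = sender.getVirtualServerAddress();
      mailboxIds.add(MailboxIdUtils.toMailboxId(requestId, senderStageId, senderAddress.workerId(),
          receiverStageId, receiverWorkerId));
    }
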
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java
index 75181de262..01447b2c72 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java
@@ -33,7 +33,8 @@ import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.mailbox.SendingMailbox;
 import org.apache.pinot.query.planner.logical.RexExpression;
 import org.apache.pinot.query.planner.partitioning.KeySelector;
-import org.apache.pinot.query.routing.VirtualServer;
+import org.apache.pinot.query.routing.VirtualServerAddress;
+import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
 import org.apache.pinot.query.runtime.operator.exchange.BlockExchange;
@@ -89,34 +90,36 @@ public class MailboxSendOperator extends MultiStageOperator {
     MailboxService mailboxService = context.getMailboxService();
     long requestId = context.getRequestId();
     int senderStageId = context.getStageId();
-    int senderWorkerId = context.getServer().virtualId();
+    int senderWorkerId = context.getServer().workerId();
     long deadlineMs = context.getDeadlineMs();
-    List<VirtualServer> receivingStageInstances = context.getMetadataMap().get(receiverStageId).getServerInstances();
+    List<WorkerMetadata> receivingMetadataList = context.getStageMetadataList().get(receiverStageId)
+        .getWorkerMetadataList();
     List<SendingMailbox> sendingMailboxes;
     if (exchangeType == RelDistribution.Type.SINGLETON) {
       // TODO: this logic should be moved into SingletonExchange
-      VirtualServer singletonInstance = null;
-      for (VirtualServer serverInstance : receivingStageInstances) {
-        if (serverInstance.getHostname().equals(mailboxService.getHostname())
-            && serverInstance.getQueryMailboxPort() == mailboxService.getPort()) {
-          Preconditions.checkState(singletonInstance == null, "Multiple instances found for SINGLETON exchange type");
-          singletonInstance = serverInstance;
+      VirtualServerAddress singletonReceiver = null;
+      for (WorkerMetadata receivingMetadata : receivingMetadataList) {
+        VirtualServerAddress receiver = receivingMetadata.getVirtualServerAddress();
+        if (receiver.hostname().equals(mailboxService.getHostname())
+            && receiver.port() == mailboxService.getPort()) {
+          Preconditions.checkState(singletonReceiver == null, "Multiple instances found for SINGLETON exchange type");
+          singletonReceiver = receiver;
         }
       }
-      Preconditions.checkState(singletonInstance != null, "Failed to find instance for SINGLETON exchange type");
+      Preconditions.checkState(singletonReceiver != null, "Failed to find instance for SINGLETON exchange type");
       String mailboxId = MailboxIdUtils.toMailboxId(requestId, senderStageId, senderWorkerId, receiverStageId,
-          singletonInstance.getVirtualId());
+          singletonReceiver.workerId());
       sendingMailboxes = Collections.singletonList(
-          mailboxService.getSendingMailbox(singletonInstance.getHostname(), singletonInstance.getQueryMailboxPort(),
-              mailboxId, deadlineMs));
+          mailboxService.getSendingMailbox(singletonReceiver.hostname(), singletonReceiver.port(), mailboxId,
+              deadlineMs));
     } else {
-      sendingMailboxes = new ArrayList<>(receivingStageInstances.size());
-      for (VirtualServer instance : receivingStageInstances) {
-        String mailboxId = MailboxIdUtils.toMailboxId(requestId, senderStageId, senderWorkerId, receiverStageId,
-            instance.getVirtualId());
+      sendingMailboxes = new ArrayList<>(receivingMetadataList.size());
+      for (WorkerMetadata receivingMetadata : receivingMetadataList) {
+        VirtualServerAddress receiver = receivingMetadata.getVirtualServerAddress();
+        String mailboxId =
+            MailboxIdUtils.toMailboxId(requestId, senderStageId, senderWorkerId, receiverStageId, receiver.workerId());
         sendingMailboxes.add(
-            mailboxService.getSendingMailbox(instance.getHostname(), instance.getQueryMailboxPort(), mailboxId,
-                deadlineMs));
+            mailboxService.getSendingMailbox(receiver.hostname(), receiver.port(), mailboxId, deadlineMs));
       }
     }
     return BlockExchange.getExchange(sendingMailboxes, exchangeType, keySelector, TransferableBlockUtils::splitBlock);
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/OperatorUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/OperatorUtils.java
index 601f169cb3..341c0b5969 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/OperatorUtils.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/OperatorUtils.java
@@ -20,13 +20,12 @@ package org.apache.pinot.query.runtime.operator.utils;
 
 import com.fasterxml.jackson.core.type.TypeReference;
 import com.fasterxml.jackson.databind.JsonNode;
-import com.google.common.base.Joiner;
 import java.util.HashMap;
 import java.util.Map;
 import org.apache.commons.lang.StringUtils;
 import org.apache.pinot.common.datablock.MetadataBlock;
 import org.apache.pinot.common.datatable.DataTable;
-import org.apache.pinot.query.planner.StageMetadata;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.runtime.operator.OperatorStats;
 import org.apache.pinot.spi.utils.JsonUtils;
@@ -68,10 +67,10 @@ public class OperatorUtils {
     return functionName;
   }
 
-  public static void recordTableName(OperatorStats operatorStats, StageMetadata operatorStageMetadata) {
-    if (!operatorStageMetadata.getScannedTables().isEmpty()) {
+  public static void recordTableName(OperatorStats operatorStats, StageMetadata stageMetadata) {
+    if (StageMetadata.getTableName(stageMetadata) != null) {
       operatorStats.recordSingleStat(DataTable.MetadataKey.TABLE.getName(),
-          Joiner.on("::").join(operatorStageMetadata.getScannedTables()));
+          StageMetadata.getTableName(stageMetadata));
     }
   }
 
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java
index c508d746ca..c67de128cb 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java
@@ -18,11 +18,12 @@
  */
 package org.apache.pinot.query.runtime.plan;
 
-import java.util.HashMap;
-import java.util.Map;
-import org.apache.pinot.query.planner.StageMetadata;
+import java.util.ArrayList;
+import java.util.List;
 import org.apache.pinot.query.planner.stage.StageNode;
-import org.apache.pinot.query.routing.VirtualServer;
+import org.apache.pinot.query.routing.StageMetadata;
+import org.apache.pinot.query.routing.VirtualServerAddress;
+import org.apache.pinot.query.routing.WorkerMetadata;
 
 
 /**
@@ -33,28 +34,28 @@ import org.apache.pinot.query.routing.VirtualServer;
  */
 public class DistributedStagePlan {
   private int _stageId;
-  private VirtualServer _server;
+  private VirtualServerAddress _server;
   private StageNode _stageRoot;
-  private Map<Integer, StageMetadata> _metadataMap;
+  private List<StageMetadata> _stageMetadataList;
 
   public DistributedStagePlan(int stageId) {
     _stageId = stageId;
-    _metadataMap = new HashMap<>();
+    _stageMetadataList = new ArrayList<>();
   }
 
-  public DistributedStagePlan(int stageId, VirtualServer server, StageNode stageRoot,
-      Map<Integer, StageMetadata> metadataMap) {
+  public DistributedStagePlan(int stageId, VirtualServerAddress server, StageNode stageRoot,
+      List<StageMetadata> stageMetadataList) {
     _stageId = stageId;
     _server = server;
     _stageRoot = stageRoot;
-    _metadataMap = metadataMap;
+    _stageMetadataList = stageMetadataList;
   }
 
   public int getStageId() {
     return _stageId;
   }
 
-  public VirtualServer getServer() {
+  public VirtualServerAddress getServer() {
     return _server;
   }
 
@@ -62,15 +63,23 @@ public class DistributedStagePlan {
     return _stageRoot;
   }
 
-  public Map<Integer, StageMetadata> getMetadataMap() {
-    return _metadataMap;
+  public List<StageMetadata> getStageMetadataList() {
+    return _stageMetadataList;
   }
 
-  public void setServer(VirtualServer serverInstance) {
-    _server = serverInstance;
+  public void setServer(VirtualServerAddress serverAddress) {
+    _server = serverAddress;
   }
 
   public void setStageRoot(StageNode stageRoot) {
     _stageRoot = stageRoot;
   }
+
+  public StageMetadata getCurrentStageMetadata() {
+    return _stageMetadataList.get(_stageId);
+  }
+
+  public WorkerMetadata getCurrentWorkerMetadata() {
+    return getCurrentStageMetadata().getWorkerMetadataList().get(_server.workerId());
+  }
 }
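
A sketch of how a worker now resolves its own metadata from the flat stage metadata list, assuming a two-stage plan and worker 0 of stage 1 (the stage root node is left out of this sketch):

    WorkerMetadata worker0 = new WorkerMetadata.Builder()
        .setVirtualServerAddress(new VirtualServerAddress("localhost", 8421, /*workerId=*/0))
        .build();
    List<StageMetadata> stageMetadataList = Arrays.asList(
        new StageMetadata.Builder().setWorkerMetadataList(Collections.singletonList(worker0)).build(),
        new StageMetadata.Builder().setWorkerMetadataList(Collections.singletonList(worker0)).build());
    DistributedStagePlan stagePlan = new DistributedStagePlan(/*stageId=*/1,
        new VirtualServerAddress("localhost", 8421, /*workerId=*/0), /*stageRoot=*/null, stageMetadataList);
    StageMetadata currentStage = stagePlan.getCurrentStageMetadata();     // stageMetadataList.get(1)
    WorkerMetadata currentWorker = stagePlan.getCurrentWorkerMetadata();  // worker 0 of that stage
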
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/OpChainExecutionContext.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/OpChainExecutionContext.java
index 0850c0414c..bb94ba32e6 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/OpChainExecutionContext.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/OpChainExecutionContext.java
@@ -18,9 +18,9 @@
  */
 package org.apache.pinot.query.runtime.plan;
 
-import java.util.Map;
+import java.util.List;
 import org.apache.pinot.query.mailbox.MailboxService;
-import org.apache.pinot.query.planner.StageMetadata;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.runtime.operator.OpChainId;
 import org.apache.pinot.query.runtime.operator.OpChainStats;
@@ -38,13 +38,13 @@ public class OpChainExecutionContext {
   private final VirtualServerAddress _server;
   private final long _timeoutMs;
   private final long _deadlineMs;
-  private final Map<Integer, StageMetadata> _metadataMap;
+  private final List<StageMetadata> _stageMetadataList;
   private final OpChainId _id;
   private final OpChainStats _stats;
   private final boolean _traceEnabled;
 
   public OpChainExecutionContext(MailboxService mailboxService, long requestId, int stageId,
-      VirtualServerAddress server, long timeoutMs, long deadlineMs, Map<Integer, StageMetadata> metadataMap,
+      VirtualServerAddress server, long timeoutMs, long deadlineMs, List<StageMetadata> stageMetadataList,
       boolean traceEnabled) {
     _mailboxService = mailboxService;
     _requestId = requestId;
@@ -52,8 +52,8 @@ public class OpChainExecutionContext {
     _server = server;
     _timeoutMs = timeoutMs;
     _deadlineMs = deadlineMs;
-    _metadataMap = metadataMap;
-    _id = new OpChainId(requestId, server.virtualId(), stageId);
+    _stageMetadataList = stageMetadataList;
+    _id = new OpChainId(requestId, server.workerId(), stageId);
     _stats = new OpChainStats(_id.toString());
     _traceEnabled = traceEnabled;
   }
@@ -61,7 +61,7 @@ public class OpChainExecutionContext {
   public OpChainExecutionContext(PlanRequestContext planRequestContext) {
     this(planRequestContext.getMailboxService(), planRequestContext.getRequestId(), planRequestContext.getStageId(),
         planRequestContext.getServer(), planRequestContext.getTimeoutMs(), planRequestContext.getDeadlineMs(),
-        planRequestContext.getMetadataMap(), planRequestContext.isTraceEnabled());
+        planRequestContext.getStageMetadataList(), planRequestContext.isTraceEnabled());
   }
 
   public MailboxService getMailboxService() {
@@ -88,8 +88,8 @@ public class OpChainExecutionContext {
     return _deadlineMs;
   }
 
-  public Map<Integer, StageMetadata> getMetadataMap() {
-    return _metadataMap;
+  public List<StageMetadata> getStageMetadataList() {
+    return _stageMetadataList;
   }
 
   public OpChainId getId() {
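
Deriving the OpChainId from VirtualServerAddress.workerId() (instead of the old virtualId) keeps op chains for different workers of the same stage distinguishable even when they share a host and port. A small illustrative sketch, assuming only the three-argument OpChainId constructor used above; the demo class name is made up.

    import org.apache.pinot.query.runtime.operator.OpChainId;

    public final class OpChainIdDemo {
      public static void main(String[] args) {
        long requestId = 42L;
        int stageId = 1;
        // One op chain per worker of the stage; the middle argument is the worker id.
        OpChainId worker0 = new OpChainId(requestId, 0, stageId);
        OpChainId worker1 = new OpChainId(requestId, 1, stageId);
        System.out.println(worker0 + " / " + worker1);
      }
    }
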
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PlanRequestContext.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PlanRequestContext.java
index cdd570a45d..b61021958c 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PlanRequestContext.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PlanRequestContext.java
@@ -20,9 +20,8 @@ package org.apache.pinot.query.runtime.plan;
 
 import java.util.ArrayList;
 import java.util.List;
-import java.util.Map;
 import org.apache.pinot.query.mailbox.MailboxService;
-import org.apache.pinot.query.planner.StageMetadata;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 
 
@@ -34,20 +33,20 @@ public class PlanRequestContext {
   private final long _timeoutMs;
   private final long _deadlineMs;
   protected final VirtualServerAddress _server;
-  protected final Map<Integer, StageMetadata> _metadataMap;
+  protected final List<StageMetadata> _stageMetadataList;
   protected final List<String> _receivingMailboxIds = new ArrayList<>();
   private final OpChainExecutionContext _opChainExecutionContext;
   private final boolean _traceEnabled;
 
   public PlanRequestContext(MailboxService mailboxService, long requestId, int stageId, long timeoutMs, long deadlineMs,
-      VirtualServerAddress server, Map<Integer, StageMetadata> metadataMap, boolean traceEnabled) {
+      VirtualServerAddress server, List<StageMetadata> stageMetadataList, boolean traceEnabled) {
     _mailboxService = mailboxService;
     _requestId = requestId;
     _stageId = stageId;
     _timeoutMs = timeoutMs;
     _deadlineMs = deadlineMs;
     _server = server;
-    _metadataMap = metadataMap;
+    _stageMetadataList = stageMetadataList;
     _traceEnabled = traceEnabled;
     _opChainExecutionContext = new OpChainExecutionContext(this);
   }
@@ -72,8 +71,8 @@ public class PlanRequestContext {
     return _server;
   }
 
-  public Map<Integer, StageMetadata> getMetadataMap() {
-    return _metadataMap;
+  public List<StageMetadata> getStageMetadataList() {
+    return _stageMetadataList;
   }
 
   public MailboxService getMailboxService() {
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/ServerRequestPlanVisitor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/ServerRequestPlanVisitor.java
index e0e2921b4a..ef2984e4ff 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/ServerRequestPlanVisitor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/ServerRequestPlanVisitor.java
@@ -50,7 +50,6 @@ import org.apache.pinot.query.planner.stage.StageNodeVisitor;
 import org.apache.pinot.query.planner.stage.TableScanNode;
 import org.apache.pinot.query.planner.stage.ValueNode;
 import org.apache.pinot.query.planner.stage.WindowNode;
-import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.runtime.plan.server.ServerPlanRequestContext;
 import org.apache.pinot.query.service.QueryConfig;
 import org.apache.pinot.spi.config.table.TableConfig;
@@ -108,8 +107,8 @@ public class ServerRequestPlanVisitor implements StageNodeVisitor<Void, ServerPl
     pinotQuery.setExplain(false);
     ServerPlanRequestContext context =
         new ServerPlanRequestContext(mailboxService, requestId, stagePlan.getStageId(), timeoutMs, deadlineMs,
-            new VirtualServerAddress(stagePlan.getServer()), stagePlan.getMetadataMap(), pinotQuery, tableType,
-            timeBoundaryInfo, traceEnabled);
+            stagePlan.getServer(), stagePlan.getStageMetadataList(), pinotQuery, tableType, timeBoundaryInfo,
+            traceEnabled);
 
     // visit the plan and create query physical plan.
     ServerRequestPlanVisitor.walkStageNode(stagePlan.getStageRoot(), context);
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
index 36f7e483cd..590a4c2902 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
@@ -18,19 +18,16 @@
  */
 package org.apache.pinot.query.runtime.plan.serde;
 
-import java.util.HashMap;
+import java.util.ArrayList;
 import java.util.List;
-import java.util.Map;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 import org.apache.pinot.common.proto.Worker;
-import org.apache.pinot.core.routing.TimeBoundaryInfo;
-import org.apache.pinot.core.transport.ServerInstance;
-import org.apache.pinot.query.planner.StageMetadata;
 import org.apache.pinot.query.planner.stage.AbstractStageNode;
 import org.apache.pinot.query.planner.stage.StageNodeSerDeUtils;
-import org.apache.pinot.query.routing.VirtualServer;
-import org.apache.pinot.query.routing.WorkerInstance;
+import org.apache.pinot.query.routing.StageMetadata;
+import org.apache.pinot.query.routing.VirtualServerAddress;
+import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.plan.DistributedStagePlan;
 
 
@@ -45,109 +42,87 @@ public class QueryPlanSerDeUtils {
 
   public static DistributedStagePlan deserialize(Worker.StagePlan stagePlan) {
     DistributedStagePlan distributedStagePlan = new DistributedStagePlan(stagePlan.getStageId());
-    distributedStagePlan.setServer(stringToInstance(stagePlan.getInstanceId()));
+    distributedStagePlan.setServer(protoToAddress(stagePlan.getVirtualAddress()));
     distributedStagePlan.setStageRoot(StageNodeSerDeUtils.deserializeStageNode(stagePlan.getStageRoot()));
-    Map<Integer, Worker.StageMetadata> metadataMap = stagePlan.getStageMetadataMap();
-    distributedStagePlan.getMetadataMap().putAll(protoMapToStageMetadataMap(metadataMap));
+    distributedStagePlan.getStageMetadataList().addAll(protoListToStageMetadataList(stagePlan.getStageMetadataList()));
     return distributedStagePlan;
   }
 
   public static Worker.StagePlan serialize(DistributedStagePlan distributedStagePlan) {
     return Worker.StagePlan.newBuilder()
         .setStageId(distributedStagePlan.getStageId())
-        .setInstanceId(instanceToString(distributedStagePlan.getServer()))
+        .setVirtualAddress(addressToProto(distributedStagePlan.getServer()))
         .setStageRoot(StageNodeSerDeUtils.serializeStageNode((AbstractStageNode) distributedStagePlan.getStageRoot()))
-        .putAllStageMetadata(stageMetadataMapToProtoMap(distributedStagePlan.getMetadataMap())).build();
+        .addAllStageMetadata(stageMetadataListToProtoList(distributedStagePlan.getStageMetadataList())).build();
   }
 
   private static final Pattern VIRTUAL_SERVER_PATTERN = Pattern.compile(
-      "(?<virtualid>[0-9]+)@(?<host>[^:]+):(?<port>[0-9]+)\\((?<grpc>[0-9]+):(?<service>[0-9]+):(?<mailbox>[0-9]+)\\)");
+      "(?<virtualid>[0-9]+)@(?<host>[^:]+):(?<port>[0-9]+)");
 
-  public static VirtualServer stringToInstance(String serverInstanceString) {
-    Matcher matcher = VIRTUAL_SERVER_PATTERN.matcher(serverInstanceString);
+  public static VirtualServerAddress protoToAddress(String virtualAddressStr) {
+    Matcher matcher = VIRTUAL_SERVER_PATTERN.matcher(virtualAddressStr);
     if (!matcher.matches()) {
-      throw new IllegalArgumentException("Unexpected serverInstanceString '" + serverInstanceString + "'. This might "
+      throw new IllegalArgumentException("Unexpected virtualAddressStr '" + virtualAddressStr + "'. This might "
           + "happen if you are upgrading from an old version of the multistage engine to the current one in a rolling "
           + "fashion.");
     }
 
     // Skipped netty and grpc port as they are not used in worker instance.
-    return new VirtualServer(new WorkerInstance(matcher.group("host"), Integer.parseInt(matcher.group("port")),
-        Integer.parseInt(matcher.group("grpc")), Integer.parseInt(matcher.group("service")),
-        Integer.parseInt(matcher.group("mailbox"))), Integer.parseInt(matcher.group("virtualid")));
+    return new VirtualServerAddress(matcher.group("host"),
+        Integer.parseInt(matcher.group("port")), Integer.parseInt(matcher.group("virtualid")));
   }
 
-  public static String instanceToString(VirtualServer serverInstance) {
-    return String.format("%s@%s:%s(%s:%s:%s)", serverInstance.getVirtualId(), serverInstance.getHostname(),
-        serverInstance.getPort(), serverInstance.getGrpcPort(), serverInstance.getQueryServicePort(),
-        serverInstance.getQueryMailboxPort());
+  public static String addressToProto(VirtualServerAddress serverAddress) {
+    return String.format("%s@%s:%s", serverAddress.workerId(), serverAddress.hostname(), serverAddress.port());
   }
 
-  public static Map<Integer, StageMetadata> protoMapToStageMetadataMap(Map<Integer, Worker.StageMetadata> protoMap) {
-    Map<Integer, StageMetadata> metadataMap = new HashMap<>();
-    for (Map.Entry<Integer, Worker.StageMetadata> e : protoMap.entrySet()) {
-      metadataMap.put(e.getKey(), fromWorkerStageMetadata(e.getValue()));
+  public static List<StageMetadata> protoListToStageMetadataList(List<Worker.StageMetadata> protoList) {
+    List<StageMetadata> stageMetadataList = new ArrayList<>();
+    for (Worker.StageMetadata protoStageMetadata : protoList) {
+      stageMetadataList.add(fromProtoStageMetadata(protoStageMetadata));
     }
-    return metadataMap;
+    return stageMetadataList;
   }
 
-  private static StageMetadata fromWorkerStageMetadata(Worker.StageMetadata workerStageMetadata) {
-    StageMetadata stageMetadata = new StageMetadata();
-    // scanned table
-    stageMetadata.getScannedTables().addAll(workerStageMetadata.getDataSourcesList());
-    // server instance to table-segments mapping
-    for (String serverInstanceString : workerStageMetadata.getInstancesList()) {
-      stageMetadata.getServerInstances().add(stringToInstance(serverInstanceString));
+  private static StageMetadata fromProtoStageMetadata(Worker.StageMetadata protoStageMetadata) {
+    StageMetadata.Builder builder = new StageMetadata.Builder();
+    List<WorkerMetadata> workerMetadataList = new ArrayList<>();
+    for (Worker.WorkerMetadata protoWorkerMetadata : protoStageMetadata.getWorkerMetadataList()) {
+      workerMetadataList.add(fromProtoWorkerMetadata(protoWorkerMetadata));
     }
-    for (Map.Entry<String, Worker.SegmentMap> instanceEntry
-        : workerStageMetadata.getInstanceToSegmentMapMap().entrySet()) {
-      Map<String, List<String>> tableToSegmentMap = new HashMap<>();
-      for (Map.Entry<String, Worker.SegmentList> tableEntry
-          : instanceEntry.getValue().getTableTypeToSegmentListMap().entrySet()) {
-        tableToSegmentMap.put(tableEntry.getKey(), tableEntry.getValue().getSegmentsList());
-      }
-      stageMetadata.getServerInstanceToSegmentsMap()
-          .put(stringToInstance(instanceEntry.getKey()).getServer(), tableToSegmentMap);
-    }
-    // time boundary info
-    if (!workerStageMetadata.getTimeColumn().isEmpty()) {
-      stageMetadata.setTimeBoundaryInfo(new TimeBoundaryInfo(workerStageMetadata.getTimeColumn(),
-          workerStageMetadata.getTimeValue()));
-    }
-    return stageMetadata;
+    builder.setWorkerMetadataList(workerMetadataList);
+    builder.putAllCustomProperties(protoStageMetadata.getCustomPropertyMap());
+    return builder.build();
+  }
+
+  private static WorkerMetadata fromProtoWorkerMetadata(Worker.WorkerMetadata protoWorkerMetadata) {
+    WorkerMetadata.Builder builder = new WorkerMetadata.Builder();
+    builder.setVirtualServerAddress(protoToAddress(protoWorkerMetadata.getVirtualAddress()));
+    builder.putAllCustomProperties(protoWorkerMetadata.getCustomPropertyMap());
+    return builder.build();
   }
 
-  public static Map<Integer, Worker.StageMetadata> stageMetadataMapToProtoMap(Map<Integer, StageMetadata> metadataMap) {
-    Map<Integer, Worker.StageMetadata> protoMap = new HashMap<>();
-    for (Map.Entry<Integer, StageMetadata> e : metadataMap.entrySet()) {
-      protoMap.put(e.getKey(), toWorkerStageMetadata(e.getValue()));
+  public static List<Worker.StageMetadata> stageMetadataListToProtoList(List<StageMetadata> stageMetadataList) {
+    List<Worker.StageMetadata> protoList = new ArrayList<>();
+    for (StageMetadata stageMetadata : stageMetadataList) {
+      protoList.add(toProtoStageMetadata(stageMetadata));
     }
-    return protoMap;
+    return protoList;
   }
 
-  private static Worker.StageMetadata toWorkerStageMetadata(StageMetadata stageMetadata) {
+  private static Worker.StageMetadata toProtoStageMetadata(StageMetadata stageMetadata) {
     Worker.StageMetadata.Builder builder = Worker.StageMetadata.newBuilder();
-    // scanned table
-    builder.addAllDataSources(stageMetadata.getScannedTables());
-    // server instance to table-segments mapping
-    for (VirtualServer serverInstance : stageMetadata.getServerInstances()) {
-      builder.addInstances(instanceToString(serverInstance));
-    }
-    for (Map.Entry<ServerInstance, Map<String, List<String>>> instanceEntry
-        : stageMetadata.getServerInstanceToSegmentsMap().entrySet()) {
-      Map<String, Worker.SegmentList> tableToSegmentMap = new HashMap<>();
-      for (Map.Entry<String, List<String>> tableEntry : instanceEntry.getValue().entrySet()) {
-        tableToSegmentMap.put(tableEntry.getKey(),
-            Worker.SegmentList.newBuilder().addAllSegments(tableEntry.getValue()).build());
-      }
-      builder.putInstanceToSegmentMap(instanceToString(new VirtualServer(instanceEntry.getKey(), 0)),
-          Worker.SegmentMap.newBuilder().putAllTableTypeToSegmentList(tableToSegmentMap).build());
-    }
-    // time boundary info
-    if (stageMetadata.getTimeBoundaryInfo() != null) {
-      builder.setTimeColumn(stageMetadata.getTimeBoundaryInfo().getTimeColumn());
-      builder.setTimeValue(stageMetadata.getTimeBoundaryInfo().getTimeValue());
+    for (WorkerMetadata workerMetadata : stageMetadata.getWorkerMetadataList()) {
+      builder.addWorkerMetadata(toProtoWorkerMetadata(workerMetadata));
     }
+    builder.putAllCustomProperty(stageMetadata.getCustomProperties());
+    return builder.build();
+  }
+
+  private static Worker.WorkerMetadata toProtoWorkerMetadata(WorkerMetadata workerMetadata) {
+    Worker.WorkerMetadata.Builder builder = Worker.WorkerMetadata.newBuilder();
+    builder.setVirtualAddress(addressToProto(workerMetadata.getVirtualServerAddress()));
+    builder.putAllCustomProperty(workerMetadata.getCustomProperties());
     return builder.build();
   }
 }
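
The serialized virtual server address shrinks from virtualId@host:port(grpc:service:mailbox) to workerId@host:port; the gRPC and query-service ports now travel with the QueryServerInstance on the dispatcher side instead of being round-tripped through every stage plan. A self-contained sketch of the new wire format using the same pattern as above (plain Java, illustrative values, no Pinot dependencies):

    import java.util.regex.Matcher;
    import java.util.regex.Pattern;

    public final class VirtualAddressFormatDemo {
      // Mirrors QueryPlanSerDeUtils.VIRTUAL_SERVER_PATTERN after this change.
      private static final Pattern PATTERN =
          Pattern.compile("(?<virtualid>[0-9]+)@(?<host>[^:]+):(?<port>[0-9]+)");

      public static void main(String[] args) {
        String wire = "2@localhost:8421";
        Matcher matcher = PATTERN.matcher(wire);
        if (!matcher.matches()) {
          throw new IllegalArgumentException("Unexpected virtual address: " + wire);
        }
        // Group names match the pattern above: virtualid is the worker id, host/port the mailbox endpoint.
        int workerId = Integer.parseInt(matcher.group("virtualid"));
        String host = matcher.group("host");
        int port = Integer.parseInt(matcher.group("port"));
        System.out.printf("workerId=%d host=%s port=%d%n", workerId, host, port);
      }
    }
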
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestContext.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestContext.java
index 9115940a30..b9a7bc25bc 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestContext.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestContext.java
@@ -18,12 +18,12 @@
  */
 package org.apache.pinot.query.runtime.plan.server;
 
-import java.util.Map;
+import java.util.List;
 import org.apache.pinot.common.request.InstanceRequest;
 import org.apache.pinot.common.request.PinotQuery;
 import org.apache.pinot.core.routing.TimeBoundaryInfo;
 import org.apache.pinot.query.mailbox.MailboxService;
-import org.apache.pinot.query.planner.StageMetadata;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.runtime.plan.PlanRequestContext;
 import org.apache.pinot.spi.config.table.TableType;
@@ -41,9 +41,9 @@ public class ServerPlanRequestContext extends PlanRequestContext {
   protected InstanceRequest _instanceRequest;
 
   public ServerPlanRequestContext(MailboxService mailboxService, long requestId, int stageId, long timeoutMs,
-      long deadlineMs, VirtualServerAddress server, Map<Integer, StageMetadata> metadataMap, PinotQuery pinotQuery,
+      long deadlineMs, VirtualServerAddress server, List<StageMetadata> stageMetadataList, PinotQuery pinotQuery,
       TableType tableType, TimeBoundaryInfo timeBoundaryInfo, boolean traceEnabled) {
-    super(mailboxService, requestId, stageId, timeoutMs, deadlineMs, server, metadataMap, traceEnabled);
+    super(mailboxService, requestId, stageId, timeoutMs, deadlineMs, server, stageMetadataList, traceEnabled);
     _pinotQuery = pinotQuery;
     _tableType = tableType;
     _timeBoundaryInfo = timeBoundaryInfo;
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AsyncQueryDispatchResponse.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AsyncQueryDispatchResponse.java
index a42d068d44..185ba4f607 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AsyncQueryDispatchResponse.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AsyncQueryDispatchResponse.java
@@ -21,7 +21,7 @@ package org.apache.pinot.query.service.dispatch;
 import io.grpc.stub.StreamObserver;
 import javax.annotation.Nullable;
 import org.apache.pinot.common.proto.Worker;
-import org.apache.pinot.query.routing.VirtualServer;
+import org.apache.pinot.query.routing.QueryServerInstance;
 
 
 /**
@@ -30,12 +30,12 @@ import org.apache.pinot.query.routing.VirtualServer;
  * {@link #getThrowable()} to check if it is null.
  */
 class AsyncQueryDispatchResponse {
-  private final VirtualServer _virtualServer;
+  private final QueryServerInstance _virtualServer;
   private final int _stageId;
   private final Worker.QueryResponse _queryResponse;
   private final Throwable _throwable;
 
-  public AsyncQueryDispatchResponse(VirtualServer virtualServer, int stageId, Worker.QueryResponse queryResponse,
+  public AsyncQueryDispatchResponse(QueryServerInstance virtualServer, int stageId, Worker.QueryResponse queryResponse,
       @Nullable Throwable throwable) {
     _virtualServer = virtualServer;
     _stageId = stageId;
@@ -43,7 +43,7 @@ class AsyncQueryDispatchResponse {
     _throwable = throwable;
   }
 
-  public VirtualServer getVirtualServer() {
+  public QueryServerInstance getVirtualServer() {
     return _virtualServer;
   }
 
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java
index d388d341b6..03861a436e 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java
@@ -25,14 +25,14 @@ import io.grpc.stub.StreamObserver;
 import java.util.function.Consumer;
 import org.apache.pinot.common.proto.PinotQueryWorkerGrpc;
 import org.apache.pinot.common.proto.Worker;
-import org.apache.pinot.query.routing.VirtualServer;
+import org.apache.pinot.query.routing.QueryServerInstance;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 
 /**
- * Dispatches a query plan to given a {@link VirtualServer}. Each {@link DispatchClient} has its own gRPC Channel and
- * Client Stub.
+ * Dispatches a query plan to a given {@link QueryServerInstance}. Each {@link DispatchClient} has its own gRPC Channel
+ * and Client Stub.
  * TODO: It might be neater to implement pooling at the client level. Two options: (1) Pass a channel provider and
  *       let that take care of pooling. (2) Create a DispatchClient interface and implement pooled/non-pooled versions.
  */
@@ -51,7 +51,7 @@ class DispatchClient {
     return _channel;
   }
 
-  public void submit(Worker.QueryRequest request, int stageId, VirtualServer virtualServer, Deadline deadline,
+  public void submit(Worker.QueryRequest request, int stageId, QueryServerInstance virtualServer, Deadline deadline,
       Consumer<AsyncQueryDispatchResponse> callback) {
     try {
       _dispatchStub.withDeadline(deadline).submit(request, new DispatchObserver(stageId, virtualServer, callback));
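
With this signature, a caller submits one request per (stage, server) pair while the per-worker address travels inside the serialized stage plan. A hedged usage sketch based only on the submit(...) signature shown in this hunk; it would have to live in the same org.apache.pinot.query.service.dispatch package since DispatchClient is package-private, and the DispatchSketch name is illustrative.

    package org.apache.pinot.query.service.dispatch;

    import io.grpc.Deadline;
    import java.util.concurrent.TimeUnit;
    import java.util.function.Consumer;
    import org.apache.pinot.common.proto.Worker;
    import org.apache.pinot.query.routing.QueryServerInstance;

    final class DispatchSketch {
      void dispatch(DispatchClient client, Worker.QueryRequest request, int stageId, QueryServerInstance server,
          Consumer<AsyncQueryDispatchResponse> callback) {
        // Bound the RPC the same way QueryDispatcher does, via a gRPC deadline.
        Deadline deadline = Deadline.after(10_000, TimeUnit.MILLISECONDS);
        // The callback is invoked with an AsyncQueryDispatchResponse on success or failure.
        client.submit(request, stageId, server, deadline, callback);
      }
    }
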
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchObserver.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchObserver.java
index b4b1494f2f..2a7425dd99 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchObserver.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchObserver.java
@@ -21,7 +21,7 @@ package org.apache.pinot.query.service.dispatch;
 import io.grpc.stub.StreamObserver;
 import java.util.function.Consumer;
 import org.apache.pinot.common.proto.Worker;
-import org.apache.pinot.query.routing.VirtualServer;
+import org.apache.pinot.query.routing.QueryServerInstance;
 
 
 /**
@@ -29,11 +29,12 @@ import org.apache.pinot.query.routing.VirtualServer;
  */
 class DispatchObserver implements StreamObserver<Worker.QueryResponse> {
   private int _stageId;
-  private VirtualServer _virtualServer;
+  private QueryServerInstance _virtualServer;
   private Consumer<AsyncQueryDispatchResponse> _callback;
   private Worker.QueryResponse _queryResponse;
 
-  public DispatchObserver(int stageId, VirtualServer virtualServer, Consumer<AsyncQueryDispatchResponse> callback) {
+  public DispatchObserver(int stageId, QueryServerInstance virtualServer,
+      Consumer<AsyncQueryDispatchResponse> callback) {
     _stageId = stageId;
     _virtualServer = virtualServer;
     _callback = callback;
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
index 0eeeac4e4f..8bcd6676a4 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
@@ -47,9 +47,10 @@ import org.apache.pinot.core.util.trace.TracedThreadFactory;
 import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.planner.ExplainPlanStageVisitor;
 import org.apache.pinot.query.planner.QueryPlan;
-import org.apache.pinot.query.planner.StageMetadata;
+import org.apache.pinot.query.planner.physical.DispatchablePlanMetadata;
 import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
-import org.apache.pinot.query.routing.VirtualServer;
+import org.apache.pinot.query.routing.QueryServerInstance;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
@@ -101,12 +102,12 @@ public class QueryDispatcher {
 
   private void cancel(long requestId, QueryPlan queryPlan) {
     Set<DispatchClient> dispatchClientSet = new HashSet<>();
-    for (Map.Entry<Integer, StageMetadata> stage : queryPlan.getStageMetadataMap().entrySet()) {
+    for (Map.Entry<Integer, DispatchablePlanMetadata> stage : queryPlan.getDispatchablePlanMetadataMap().entrySet()) {
       int stageId = stage.getKey();
       // stage rooting at a mailbox receive node means reduce stage.
       if (!(queryPlan.getQueryStageMap().get(stageId) instanceof MailboxReceiveNode)) {
-        List<VirtualServer> serverInstances = stage.getValue().getServerInstances();
-        for (VirtualServer serverInstance : serverInstances) {
+        Set<QueryServerInstance> serverInstances = stage.getValue().getServerInstanceToWorkerIdMap().keySet();
+        for (QueryServerInstance serverInstance : serverInstances) {
           String host = serverInstance.getHostname();
           int servicePort = serverInstance.getQueryServicePort();
           dispatchClientSet.add(getOrCreateDispatchClient(host, servicePort));
@@ -125,25 +126,32 @@ public class QueryDispatcher {
     Deadline deadline = Deadline.after(timeoutMs, TimeUnit.MILLISECONDS);
     BlockingQueue<AsyncQueryDispatchResponse> dispatchCallbacks = new LinkedBlockingQueue<>();
     int dispatchCalls = 0;
-    for (Map.Entry<Integer, StageMetadata> stage : queryPlan.getStageMetadataMap().entrySet()) {
+    for (Map.Entry<Integer, DispatchablePlanMetadata> stage : queryPlan.getDispatchablePlanMetadataMap().entrySet()) {
       int stageId = stage.getKey();
       // stage rooting at a mailbox receive node means reduce stage.
       if (queryPlan.getQueryStageMap().get(stageId) instanceof MailboxReceiveNode) {
         reduceStageId = stageId;
       } else {
-        List<VirtualServer> serverInstances = stage.getValue().getServerInstances();
-        for (VirtualServer serverInstance : serverInstances) {
-          String host = serverInstance.getHostname();
-          int servicePort = serverInstance.getQueryServicePort();
-          DispatchClient client = getOrCreateDispatchClient(host, servicePort);
-          dispatchCalls++;
-          _executorService.submit(() -> {
-            client.submit(Worker.QueryRequest.newBuilder().setStagePlan(
-                    QueryPlanSerDeUtils.serialize(constructDistributedStagePlan(queryPlan, stageId, serverInstance)))
-                .putMetadata(QueryConfig.KEY_OF_BROKER_REQUEST_ID, String.valueOf(requestId))
-                .putMetadata(QueryConfig.KEY_OF_BROKER_REQUEST_TIMEOUT_MS, String.valueOf(timeoutMs))
-                .putAllMetadata(queryOptions).build(), stageId, serverInstance, deadline, dispatchCallbacks::offer);
-          });
+        for (Map.Entry<QueryServerInstance, List<Integer>> queryServerEntry
+            : stage.getValue().getServerInstanceToWorkerIdMap().entrySet()) {
+          QueryServerInstance queryServerInstance = queryServerEntry.getKey();
+          for (int workerId : queryServerEntry.getValue()) {
+            String host = queryServerInstance.getHostname();
+            int servicePort = queryServerInstance.getQueryServicePort();
+            int mailboxPort = queryServerInstance.getQueryMailboxPort();
+            VirtualServerAddress virtualServerAddress = new VirtualServerAddress(host, mailboxPort, workerId);
+            DispatchClient client = getOrCreateDispatchClient(host, servicePort);
+            dispatchCalls++;
+            _executorService.submit(() -> {
+              client.submit(Worker.QueryRequest.newBuilder().setStagePlan(
+                      QueryPlanSerDeUtils.serialize(constructDistributedStagePlan(queryPlan, stageId,
+                          virtualServerAddress)))
+                  .putMetadata(QueryConfig.KEY_OF_BROKER_REQUEST_ID, String.valueOf(requestId))
+                  .putMetadata(QueryConfig.KEY_OF_BROKER_REQUEST_TIMEOUT_MS, String.valueOf(timeoutMs))
+                  .putAllMetadata(queryOptions).build(), stageId, queryServerInstance, deadline,
+                  dispatchCallbacks::offer);
+            });
+          }
         }
       }
     }
@@ -181,7 +189,7 @@ public class QueryDispatcher {
     VirtualServerAddress server = new VirtualServerAddress(mailboxService.getHostname(), mailboxService.getPort(), 0);
     OpChainExecutionContext context =
         new OpChainExecutionContext(mailboxService, requestId, reduceStageId, server, timeoutMs,
-            System.currentTimeMillis() + timeoutMs, queryPlan.getStageMetadataMap(), traceEnabled);
+            System.currentTimeMillis() + timeoutMs, queryPlan.getStageMetadataList(), traceEnabled);
     MailboxReceiveOperator mailboxReceiveOperator = createReduceStageOperator(context, reduceNode.getSenderStageId());
     List<DataBlock> resultDataBlocks =
         reduceMailboxReceive(mailboxReceiveOperator, timeoutMs, statsAggregatorMap, queryPlan, context.getStats());
@@ -191,9 +199,9 @@ public class QueryDispatcher {
 
   @VisibleForTesting
   public static DistributedStagePlan constructDistributedStagePlan(QueryPlan queryPlan, int stageId,
-      VirtualServer serverInstance) {
-    return new DistributedStagePlan(stageId, serverInstance, queryPlan.getQueryStageMap().get(stageId),
-        queryPlan.getStageMetadataMap());
+      VirtualServerAddress serverAddress) {
+    return new DistributedStagePlan(stageId, serverAddress, queryPlan.getQueryStageMap().get(stageId),
+        queryPlan.getStageMetadataList());
   }
 
   private static List<DataBlock> reduceMailboxReceive(MailboxReceiveOperator mailboxReceiveOperator, long timeoutMs,
@@ -223,8 +231,9 @@ public class QueryDispatcher {
             rootStatsAggregator.aggregate(null, entry.getValue().getExecutionStats(), new HashMap<>());
             if (stageStatsAggregator != null) {
               if (queryPlan != null) {
-                StageMetadata operatorStageMetadata = queryPlan.getStageMetadataMap().get(operatorStats.getStageId());
-                OperatorUtils.recordTableName(operatorStats, operatorStageMetadata);
+                StageMetadata stageMetadata = queryPlan.getStageMetadataList()
+                    .get(operatorStats.getStageId());
+                OperatorUtils.recordTableName(operatorStats, stageMetadata);
               }
               stageStatsAggregator.aggregate(null, entry.getValue().getExecutionStats(), new HashMap<>());
             }
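
The dispatch loop now fans out one stage plan per worker rather than one per server: for every QueryServerInstance it walks the assigned worker ids, forms a VirtualServerAddress from the server's hostname, mailbox port and the worker id, and opens the dispatch channel against the query-service port. A minimal sketch of that fan-out using only the getters and constructor visible in this hunk; the class and method names are illustrative.

    import java.util.ArrayList;
    import java.util.List;
    import java.util.Map;
    import org.apache.pinot.query.routing.QueryServerInstance;
    import org.apache.pinot.query.routing.VirtualServerAddress;

    final class WorkerFanOutSketch {
      // One mailbox-facing address per (server, workerId) pair; the real dispatcher serializes a stage plan
      // for each address and submits it via a client connected to server.getQueryServicePort().
      static List<VirtualServerAddress> perWorkerAddresses(Map<QueryServerInstance, List<Integer>> serverToWorkerIds) {
        List<VirtualServerAddress> addresses = new ArrayList<>();
        for (Map.Entry<QueryServerInstance, List<Integer>> entry : serverToWorkerIds.entrySet()) {
          QueryServerInstance server = entry.getKey();
          for (int workerId : entry.getValue()) {
            addresses.add(new VirtualServerAddress(server.getHostname(), server.getQueryMailboxPort(), workerId));
          }
        }
        return addresses;
      }
    }
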
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
index 535055484c..0d32e8f7bd 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
@@ -32,9 +32,7 @@ import org.apache.pinot.query.QueryServerEnclosure;
 import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.planner.QueryPlan;
 import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
-import org.apache.pinot.query.routing.VirtualServer;
-import org.apache.pinot.query.routing.WorkerInstance;
-import org.apache.pinot.query.runtime.plan.DistributedStagePlan;
+import org.apache.pinot.query.routing.QueryServerInstance;
 import org.apache.pinot.query.service.QueryConfig;
 import org.apache.pinot.query.service.dispatch.QueryDispatcher;
 import org.apache.pinot.query.testutils.MockInstanceDataManagerFactory;
@@ -142,8 +140,8 @@ public class QueryRunnerTest extends QueryRunnerTestBase {
     // this is only use for test identifier purpose.
     int port1 = server1.getPort();
     int port2 = server2.getPort();
-    _servers.put(new WorkerInstance("localhost", port1, port1, port1, port1), server1);
-    _servers.put(new WorkerInstance("localhost", port2, port2, port2, port2), server2);
+    _servers.put(new QueryServerInstance("localhost", port1, port1), server1);
+    _servers.put(new QueryServerInstance("localhost", port2, port2), server2);
   }
 
   @AfterClass
@@ -193,15 +191,11 @@ public class QueryRunnerTest extends QueryRunnerTestBase {
             QueryConfig.KEY_OF_BROKER_REQUEST_TIMEOUT_MS,
             String.valueOf(CommonConstants.Broker.DEFAULT_BROKER_TIMEOUT_MS));
     int reducerStageId = -1;
-    for (int stageId : queryPlan.getStageMetadataMap().keySet()) {
+    for (int stageId : queryPlan.getDispatchablePlanMetadataMap().keySet()) {
       if (queryPlan.getQueryStageMap().get(stageId) instanceof MailboxReceiveNode) {
         reducerStageId = stageId;
       } else {
-        for (VirtualServer serverInstance : queryPlan.getStageMetadataMap().get(stageId).getServerInstances()) {
-          DistributedStagePlan distributedStagePlan =
-              QueryDispatcher.constructDistributedStagePlan(queryPlan, stageId, serverInstance);
-          _servers.get(serverInstance.getServer()).processQuery(distributedStagePlan, requestMetadataMap);
-        }
+        processDistributedStagePlans(queryPlan, stageId, requestMetadataMap);
       }
     }
     Preconditions.checkState(reducerStageId != -1);
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTestBase.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTestBase.java
index 6af12ba67a..5e45b765c8 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTestBase.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTestBase.java
@@ -43,14 +43,14 @@ import org.apache.commons.codec.DecoderException;
 import org.apache.commons.codec.binary.Hex;
 import org.apache.pinot.common.response.broker.ResultTable;
 import org.apache.pinot.core.query.reduce.ExecutionStatsAggregator;
-import org.apache.pinot.core.transport.ServerInstance;
 import org.apache.pinot.query.QueryEnvironment;
 import org.apache.pinot.query.QueryServerEnclosure;
 import org.apache.pinot.query.QueryTestSet;
 import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.planner.QueryPlan;
 import org.apache.pinot.query.planner.stage.MailboxReceiveNode;
-import org.apache.pinot.query.routing.VirtualServer;
+import org.apache.pinot.query.routing.QueryServerInstance;
+import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.runtime.plan.DistributedStagePlan;
 import org.apache.pinot.query.service.QueryConfig;
 import org.apache.pinot.query.service.dispatch.QueryDispatcher;
@@ -76,7 +76,7 @@ public abstract class QueryRunnerTestBase extends QueryTestSet {
   protected QueryEnvironment _queryEnvironment;
   protected String _reducerHostname;
   protected int _reducerGrpcPort;
-  protected Map<ServerInstance, QueryServerEnclosure> _servers = new HashMap<>();
+  protected Map<QueryServerInstance, QueryServerEnclosure> _servers = new HashMap<>();
   protected MailboxService _mailboxService;
 
   static {
@@ -107,15 +107,11 @@ public abstract class QueryRunnerTestBase extends QueryTestSet {
     }
 
     int reducerStageId = -1;
-    for (int stageId : queryPlan.getStageMetadataMap().keySet()) {
+    for (int stageId : queryPlan.getDispatchablePlanMetadataMap().keySet()) {
       if (queryPlan.getQueryStageMap().get(stageId) instanceof MailboxReceiveNode) {
         reducerStageId = stageId;
       } else {
-        for (VirtualServer serverInstance : queryPlan.getStageMetadataMap().get(stageId).getServerInstances()) {
-          DistributedStagePlan distributedStagePlan =
-              QueryDispatcher.constructDistributedStagePlan(queryPlan, stageId, serverInstance);
-          _servers.get(serverInstance.getServer()).processQuery(distributedStagePlan, requestMetadataMap);
-        }
+        processDistributedStagePlans(queryPlan, stageId, requestMetadataMap);
       }
       if (executionStatsAggregatorMap != null) {
         executionStatsAggregatorMap.put(stageId, new ExecutionStatsAggregator(true));
@@ -128,6 +124,20 @@ public abstract class QueryRunnerTestBase extends QueryTestSet {
     return resultTable.getRows();
   }
 
+  protected void processDistributedStagePlans(QueryPlan queryPlan, int stageId,
+      Map<String, String> requestMetadataMap) {
+    Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap =
+        queryPlan.getDispatchablePlanMetadataMap().get(stageId).getServerInstanceToWorkerIdMap();
+    for (Map.Entry<QueryServerInstance, List<Integer>> entry : serverInstanceToWorkerIdMap.entrySet()) {
+      QueryServerInstance server = entry.getKey();
+      for (int workerId : entry.getValue()) {
+        DistributedStagePlan distributedStagePlan = QueryDispatcher.constructDistributedStagePlan(
+            queryPlan, stageId, new VirtualServerAddress(server, workerId));
+        _servers.get(server).processQuery(distributedStagePlan, requestMetadataMap);
+      }
+    }
+  }
+
   protected List<Object[]> queryH2(String sql)
       throws Exception {
     int firstSemi = sql.indexOf(';');
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperatorTest.java
index 5f64e3a04b..399eb5e998 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperatorTest.java
@@ -21,6 +21,8 @@ package org.apache.pinot.query.runtime.operator;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 import org.apache.calcite.rel.RelDistribution;
 import org.apache.pinot.common.datablock.MetadataBlock;
 import org.apache.pinot.common.exception.QueryException;
@@ -28,9 +30,9 @@ import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.mailbox.MailboxIdUtils;
 import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.mailbox.ReceivingMailbox;
-import org.apache.pinot.query.planner.StageMetadata;
-import org.apache.pinot.query.routing.VirtualServer;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
+import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
@@ -58,25 +60,30 @@ public class MailboxReceiveOperatorTest {
   @Mock
   private MailboxService _mailboxService;
   @Mock
-  private VirtualServer _server1;
-  @Mock
-  private VirtualServer _server2;
-  @Mock
   private ReceivingMailbox _mailbox1;
   @Mock
   private ReceivingMailbox _mailbox2;
+  private List<StageMetadata> _stageMetadataListBoth;
+  private List<StageMetadata> _stageMetadataList1;
 
   @BeforeMethod
   public void setUp() {
     _mocks = MockitoAnnotations.openMocks(this);
     when(_mailboxService.getHostname()).thenReturn("localhost");
     when(_mailboxService.getPort()).thenReturn(123);
-    when(_server1.getHostname()).thenReturn("localhost");
-    when(_server1.getQueryMailboxPort()).thenReturn(123);
-    when(_server1.getVirtualId()).thenReturn(0);
-    when(_server2.getHostname()).thenReturn("localhost");
-    when(_server2.getQueryMailboxPort()).thenReturn(123);
-    when(_server2.getVirtualId()).thenReturn(1);
+    VirtualServerAddress server1 = new VirtualServerAddress("localhost", 123, 0);
+    VirtualServerAddress server2 = new VirtualServerAddress("localhost", 123, 1);
+    StageMetadata stageMetadata = new StageMetadata.Builder()
+        .setWorkerMetadataList(Stream.of(server1, server2).map(
+            s -> new WorkerMetadata.Builder().setVirtualServerAddress(s).build()).collect(Collectors.toList()))
+        .build();
+    // the receive operator runs in stage 0 and pulls from sender stage 1, so the sender's metadata sits at index 1
+    _stageMetadataListBoth = Arrays.asList(null, stageMetadata);
+    stageMetadata = new StageMetadata.Builder()
+        .setWorkerMetadataList(Stream.of(server1).map(
+            s -> new WorkerMetadata.Builder().setVirtualServerAddress(s).build()).collect(Collectors.toList()))
+        .build();
+    _stageMetadataList1 = Arrays.asList(null, stageMetadata);
   }
 
   @AfterMethod
@@ -87,24 +94,24 @@ public class MailboxReceiveOperatorTest {
 
   @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = "Failed to find instance.*")
   public void shouldThrowSingletonNoMatchMailboxServer() {
-    when(_server1.getQueryMailboxPort()).thenReturn(456);
-    when(_server2.getQueryMailboxPort()).thenReturn(789);
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Arrays.asList(_server1, _server2));
+    VirtualServerAddress server1 = new VirtualServerAddress("localhost", 456, 0);
+    VirtualServerAddress server2 = new VirtualServerAddress("localhost", 789, 1);
+    StageMetadata stageMetadata = new StageMetadata.Builder()
+        .setWorkerMetadataList(Stream.of(server1, server2).map(
+            s -> new WorkerMetadata.Builder().setVirtualServerAddress(s).build()).collect(Collectors.toList()))
+        .build();
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            Arrays.asList(null, stageMetadata), false);
     //noinspection resource
     new MailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, 1);
   }
 
   @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = "Multiple instances.*")
   public void shouldThrowReceiveSingletonFromMultiMatchMailboxServer() {
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Arrays.asList(_server1, _server2));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataListBoth, false);
     //noinspection resource
     new MailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, 1);
   }
@@ -113,7 +120,7 @@ public class MailboxReceiveOperatorTest {
   public void shouldThrowRangeDistributionNotSupported() {
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.emptyMap(), false);
+            Collections.emptyList(), false);
     //noinspection resource
     new MailboxReceiveOperator(context, RelDistribution.Type.RANGE_DISTRIBUTED, 1);
   }
@@ -124,11 +131,9 @@ public class MailboxReceiveOperatorTest {
     when(_mailboxService.getReceivingMailbox(MAILBOX_ID_1)).thenReturn(_mailbox1);
 
     // Short timeoutMs should result in timeout
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Collections.singletonList(_server1));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, 10L, System.currentTimeMillis() + 10L,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataList1, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, 1)) {
       Thread.sleep(100L);
       TransferableBlock mailbox = receiveOp.nextBlock();
@@ -139,7 +144,7 @@ public class MailboxReceiveOperatorTest {
 
     // Longer timeout or default timeout (10s) doesn't result in timeout
     context = new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, 10_000L,
-        System.currentTimeMillis() + 10_000L, Collections.singletonMap(1, stageMetadata), false);
+        System.currentTimeMillis() + 10_000L, _stageMetadataList1, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, 1)) {
       Thread.sleep(100L);
       TransferableBlock mailbox = receiveOp.nextBlock();
@@ -151,11 +156,9 @@ public class MailboxReceiveOperatorTest {
   public void shouldReceiveSingletonNullMailbox() {
     when(_mailboxService.getReceivingMailbox(MAILBOX_ID_1)).thenReturn(_mailbox1);
 
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Collections.singletonList(_server1));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataList1, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, 1)) {
       assertTrue(receiveOp.nextBlock().isNoOpBlock());
     }
@@ -166,11 +169,9 @@ public class MailboxReceiveOperatorTest {
     when(_mailboxService.getReceivingMailbox(MAILBOX_ID_1)).thenReturn(_mailbox1);
     when(_mailbox1.poll()).thenReturn(TransferableBlockUtils.getEndOfStreamTransferableBlock());
 
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Collections.singletonList(_server1));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataList1, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, 1)) {
       assertTrue(receiveOp.nextBlock().isEndOfStreamBlock());
     }
@@ -183,11 +184,9 @@ public class MailboxReceiveOperatorTest {
     when(_mailbox1.poll()).thenReturn(OperatorTestUtil.block(DATA_SCHEMA, row),
         TransferableBlockUtils.getEndOfStreamTransferableBlock());
 
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Collections.singletonList(_server1));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataList1, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, 1)) {
       List<Object[]> actualRows = receiveOp.nextBlock().getContainer();
       assertEquals(actualRows.size(), 1);
@@ -203,11 +202,9 @@ public class MailboxReceiveOperatorTest {
     when(_mailbox1.poll()).thenReturn(
         TransferableBlockUtils.getErrorTransferableBlock(new RuntimeException(errorMessage)));
 
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Collections.singletonList(_server1));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataList1, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, 1)) {
       TransferableBlock block = receiveOp.nextBlock();
       assertTrue(block.isErrorBlock());
@@ -224,11 +221,9 @@ public class MailboxReceiveOperatorTest {
     when(_mailbox2.poll()).thenReturn(OperatorTestUtil.block(DATA_SCHEMA, row),
         TransferableBlockUtils.getEndOfStreamTransferableBlock());
 
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Arrays.asList(_server1, _server2));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataListBoth, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.HASH_DISTRIBUTED,
         1)) {
       List<Object[]> actualRows = receiveOp.nextBlock().getContainer();
@@ -250,11 +245,9 @@ public class MailboxReceiveOperatorTest {
     when(_mailbox2.poll()).thenReturn(OperatorTestUtil.block(DATA_SCHEMA, row2),
         TransferableBlockUtils.getEndOfStreamTransferableBlock());
 
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Arrays.asList(_server1, _server2));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataListBoth, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.HASH_DISTRIBUTED,
         1)) {
       // Receive first block from server1
@@ -278,11 +271,9 @@ public class MailboxReceiveOperatorTest {
     when(_mailbox2.poll()).thenReturn(OperatorTestUtil.block(DATA_SCHEMA, row),
         TransferableBlockUtils.getEndOfStreamTransferableBlock());
 
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Arrays.asList(_server1, _server2));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataListBoth, false);
     try (MailboxReceiveOperator receiveOp = new MailboxReceiveOperator(context, RelDistribution.Type.HASH_DISTRIBUTED,
         1)) {
       TransferableBlock block = receiveOp.nextBlock();
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxSendOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxSendOperatorTest.java
index 612c2a3611..4efa6aba34 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxSendOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxSendOperatorTest.java
@@ -23,9 +23,9 @@ import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.mailbox.MailboxService;
-import org.apache.pinot.query.planner.StageMetadata;
-import org.apache.pinot.query.routing.VirtualServer;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
+import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
 import org.apache.pinot.query.runtime.operator.exchange.BlockExchange;
@@ -49,12 +49,11 @@ import static org.testng.Assert.assertTrue;
 
 public class MailboxSendOperatorTest {
   private static final int SENDER_STAGE_ID = 1;
-  private static final int RECEIVER_STAGE_ID = 0;
 
   private AutoCloseable _mocks;
 
   @Mock
-  private VirtualServer _server;
+  private VirtualServerAddress _server;
   @Mock
   private MultiStageOperator _sourceOperator;
   @Mock
@@ -65,9 +64,9 @@ public class MailboxSendOperatorTest {
   @BeforeMethod
   public void setUp() {
     _mocks = MockitoAnnotations.openMocks(this);
-    when(_server.getHostname()).thenReturn("localhost");
-    when(_server.getQueryMailboxPort()).thenReturn(123);
-    when(_server.getVirtualId()).thenReturn(0);
+    when(_server.hostname()).thenReturn("mock");
+    when(_server.port()).thenReturn(0);
+    when(_server.workerId()).thenReturn(0);
   }
 
   @AfterMethod
@@ -166,12 +165,12 @@ public class MailboxSendOperatorTest {
   }
 
   private MailboxSendOperator getMailboxSendOperator() {
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Collections.singletonList(_server));
-    Map<Integer, StageMetadata> stageMetadataMap = Collections.singletonMap(RECEIVER_STAGE_ID, stageMetadata);
+    StageMetadata stageMetadata = new StageMetadata.Builder()
+        .setWorkerMetadataList(Collections.singletonList(
+            new WorkerMetadata.Builder().setVirtualServerAddress(_server).build())).build();
     OpChainExecutionContext context =
-        new OpChainExecutionContext(_mailboxService, 0, SENDER_STAGE_ID, new VirtualServerAddress(_server),
-            Long.MAX_VALUE, Long.MAX_VALUE, stageMetadataMap, false);
+        new OpChainExecutionContext(_mailboxService, 0, SENDER_STAGE_ID, _server, Long.MAX_VALUE, Long.MAX_VALUE,
+            Collections.singletonList(stageMetadata), false);
     return new MailboxSendOperator(context, _sourceOperator, _exchange, null, null, false);
   }
 }
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OpChainTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OpChainTest.java
index 87b4c98395..9e4d56e940 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OpChainTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OpChainTest.java
@@ -21,10 +21,11 @@ package org.apache.pinot.query.runtime.operator;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Stack;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 import javax.annotation.Nullable;
 import org.apache.calcite.rel.RelDistribution;
 import org.apache.pinot.common.datatable.DataTable;
@@ -35,10 +36,10 @@ import org.apache.pinot.core.query.request.context.QueryContext;
 import org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils;
 import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.mailbox.ReceivingMailbox;
-import org.apache.pinot.query.planner.StageMetadata;
 import org.apache.pinot.query.planner.logical.RexExpression;
-import org.apache.pinot.query.routing.VirtualServer;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
+import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
 import org.apache.pinot.query.runtime.operator.exchange.BlockExchange;
@@ -64,8 +65,6 @@ public class OpChainTest {
 
   private AutoCloseable _mocks;
   @Mock
-  private VirtualServer _server;
-  @Mock
   private MultiStageOperator _sourceOperator;
   @Mock
   private MailboxService _mailboxService1;
@@ -78,12 +77,17 @@ public class OpChainTest {
   @Mock
   private BlockExchange _exchange;
 
+  private VirtualServerAddress _serverAddress;
+  private StageMetadata _receivingStageMetadata;
+
   @BeforeMethod
   public void setUp() {
     _mocks = MockitoAnnotations.openMocks(this);
-    when(_server.getHostname()).thenReturn("localhost");
-    when(_server.getQueryMailboxPort()).thenReturn(123);
-    when(_server.getVirtualId()).thenReturn(0);
+    _serverAddress = new VirtualServerAddress("localhost", 123, 0);
+    _receivingStageMetadata = new StageMetadata.Builder()
+            .setWorkerMetadataList(Stream.of(_serverAddress).map(
+                s -> new WorkerMetadata.Builder().setVirtualServerAddress(s).build()).collect(Collectors.toList()))
+            .build();
 
     when(_mailboxService1.getReceivingMailbox(any())).thenReturn(_mailbox1);
     when(_mailboxService2.getReceivingMailbox(any())).thenReturn(_mailbox2);
@@ -176,12 +180,9 @@ public class OpChainTest {
 
     int receivedStageId = 2;
     int senderStageId = 1;
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Collections.singletonList(_server));
-    Map<Integer, StageMetadata> stageMetadataMap = Collections.singletonMap(receivedStageId, stageMetadata);
     OpChainExecutionContext context =
-        new OpChainExecutionContext(_mailboxService1, 1, senderStageId, new VirtualServerAddress(_server), 1000,
-            System.currentTimeMillis() + 1000, stageMetadataMap, true);
+        new OpChainExecutionContext(_mailboxService1, 1, senderStageId, _serverAddress, 1000,
+            System.currentTimeMillis() + 1000, Arrays.asList(null, null, _receivingStageMetadata), true);
 
     Stack<MultiStageOperator> operators =
         getFullOpchain(receivedStageId, senderStageId, context, dummyOperatorWaitTime);
@@ -194,8 +195,8 @@ public class OpChainTest {
     opChain.getStats().queued();
 
     OpChainExecutionContext secondStageContext =
-        new OpChainExecutionContext(_mailboxService2, 1, senderStageId + 1, new VirtualServerAddress(_server), 1000,
-            System.currentTimeMillis() + 1000, stageMetadataMap, true);
+        new OpChainExecutionContext(_mailboxService2, 1, senderStageId + 1, _serverAddress, 1000,
+            System.currentTimeMillis() + 1000, Arrays.asList(null, null, _receivingStageMetadata), true);
 
     MailboxReceiveOperator secondStageReceiveOp =
         new MailboxReceiveOperator(secondStageContext, RelDistribution.Type.BROADCAST_DISTRIBUTED, senderStageId + 1);
@@ -219,15 +220,11 @@ public class OpChainTest {
 
     int receivedStageId = 2;
     int senderStageId = 1;
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Collections.singletonList(_server));
-    Map<Integer, StageMetadata> stageMetadataMap = new HashMap<>();
-    for (int i = 0; i < 3; i++) {
-      stageMetadataMap.put(i, stageMetadata);
-    }
+    List<StageMetadata> metadataList =
+        Arrays.asList(_receivingStageMetadata, _receivingStageMetadata, _receivingStageMetadata);
     OpChainExecutionContext context =
-        new OpChainExecutionContext(_mailboxService1, 1, senderStageId, new VirtualServerAddress(_server), 1000,
-            System.currentTimeMillis() + 1000, stageMetadataMap, false);
+        new OpChainExecutionContext(_mailboxService1, 1, senderStageId, _serverAddress, 1000,
+            System.currentTimeMillis() + 1000, metadataList, false);
 
     Stack<MultiStageOperator> operators =
         getFullOpchain(receivedStageId, senderStageId, context, dummyOperatorWaitTime);
@@ -238,8 +235,8 @@ public class OpChainTest {
     opChain.getStats().queued();
 
     OpChainExecutionContext secondStageContext =
-        new OpChainExecutionContext(_mailboxService2, 1, senderStageId + 1, new VirtualServerAddress(_server), 1000,
-            System.currentTimeMillis() + 1000, stageMetadataMap, false);
+        new OpChainExecutionContext(_mailboxService2, 1, senderStageId + 1, _serverAddress, 1000,
+            System.currentTimeMillis() + 1000, metadataList, false);
     MailboxReceiveOperator secondStageReceiveOp =
         new MailboxReceiveOperator(secondStageContext, RelDistribution.Type.BROADCAST_DISTRIBUTED, senderStageId);
 
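The OpChainTest changes above drop the mocked VirtualServer and the per-stage Map<Integer, StageMetadata> in favor of a concrete VirtualServerAddress plus stage metadata assembled through the new StageMetadata/WorkerMetadata builders and looked up by stage id in a List. A minimal sketch of that wiring, assuming only the constructor and builder signatures visible in the diff above (illustrative only, not code taken from the commit):

    import java.util.Arrays;
    import java.util.List;
    import java.util.stream.Collectors;
    import java.util.stream.Stream;
    import org.apache.pinot.query.routing.StageMetadata;
    import org.apache.pinot.query.routing.VirtualServerAddress;
    import org.apache.pinot.query.routing.WorkerMetadata;

    public class StageMetadataSketch {
      public static List<StageMetadata> singleWorkerStageMetadataList() {
        // one worker at localhost:123 with worker (virtual) id 0
        VirtualServerAddress worker = new VirtualServerAddress("localhost", 123, 0);
        StageMetadata receivingStage = new StageMetadata.Builder()
            .setWorkerMetadataList(Stream.of(worker)
                .map(s -> new WorkerMetadata.Builder().setVirtualServerAddress(s).build())
                .collect(Collectors.toList()))
            .build();
        // stage metadata is now addressed by list index (stage id) instead of a Map keyed by stage id
        return Arrays.asList(null, null, receivingStage);
      }
    }
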
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java
index 575df48d30..0d04d2bb29 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java
@@ -18,8 +18,8 @@
  */
 package org.apache.pinot.query.runtime.operator;
 
+import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.HashMap;
 import java.util.List;
 import org.apache.pinot.common.datablock.DataBlock;
 import org.apache.pinot.common.utils.DataSchema;
@@ -65,18 +65,18 @@ public class OperatorTestUtil {
   public static OpChainExecutionContext getDefaultContext() {
     VirtualServerAddress virtualServerAddress = new VirtualServerAddress("mock", 80, 0);
     return new OpChainExecutionContext(null, 1, 2, virtualServerAddress, Long.MAX_VALUE, Long.MAX_VALUE,
-        new HashMap<>(), true);
+        new ArrayList<>(), true);
   }
 
   public static OpChainExecutionContext getDefaultContextWithTracingDisabled() {
     VirtualServerAddress virtualServerAddress = new VirtualServerAddress("mock", 80, 0);
     return new OpChainExecutionContext(null, 1, 2, virtualServerAddress, Long.MAX_VALUE, Long.MAX_VALUE,
-        new HashMap<>(), false);
+        new ArrayList<>(), false);
   }
 
   public static OpChainExecutionContext getContext(long requestId, int stageId,
       VirtualServerAddress virtualServerAddress) {
     return new OpChainExecutionContext(null, requestId, stageId, virtualServerAddress, Long.MAX_VALUE, Long.MAX_VALUE,
-        new HashMap<>(), true);
+        new ArrayList<>(), true);
   }
 }
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/SortedMailboxReceiveOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/SortedMailboxReceiveOperatorTest.java
index 8091b0e8da..90e8225278 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/SortedMailboxReceiveOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/SortedMailboxReceiveOperatorTest.java
@@ -21,6 +21,8 @@ package org.apache.pinot.query.runtime.operator;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 import org.apache.calcite.rel.RelDistribution;
 import org.apache.calcite.rel.RelFieldCollation;
 import org.apache.pinot.common.datablock.MetadataBlock;
@@ -29,10 +31,10 @@ import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.mailbox.MailboxIdUtils;
 import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.mailbox.ReceivingMailbox;
-import org.apache.pinot.query.planner.StageMetadata;
 import org.apache.pinot.query.planner.logical.RexExpression;
-import org.apache.pinot.query.routing.VirtualServer;
+import org.apache.pinot.query.routing.StageMetadata;
 import org.apache.pinot.query.routing.VirtualServerAddress;
+import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
@@ -64,25 +66,31 @@ public class SortedMailboxReceiveOperatorTest {
   @Mock
   private MailboxService _mailboxService;
   @Mock
-  private VirtualServer _server1;
-  @Mock
-  private VirtualServer _server2;
-  @Mock
   private ReceivingMailbox _mailbox1;
   @Mock
   private ReceivingMailbox _mailbox2;
 
+  private List<StageMetadata> _stageMetadataListBoth;
+  private List<StageMetadata> _stageMetadataList1;
+
   @BeforeMethod
   public void setUp() {
     _mocks = MockitoAnnotations.openMocks(this);
     when(_mailboxService.getHostname()).thenReturn("localhost");
     when(_mailboxService.getPort()).thenReturn(123);
-    when(_server1.getHostname()).thenReturn("localhost");
-    when(_server1.getQueryMailboxPort()).thenReturn(123);
-    when(_server1.getVirtualId()).thenReturn(0);
-    when(_server2.getHostname()).thenReturn("localhost");
-    when(_server2.getQueryMailboxPort()).thenReturn(123);
-    when(_server2.getVirtualId()).thenReturn(1);
+    VirtualServerAddress server1 = new VirtualServerAddress("localhost", 123, 0);
+    VirtualServerAddress server2 = new VirtualServerAddress("localhost", 123, 1);
+    StageMetadata stageMetadata = new StageMetadata.Builder()
+        .setWorkerMetadataList(Stream.of(server1, server2).map(
+            s -> new WorkerMetadata.Builder().setVirtualServerAddress(s).build()).collect(Collectors.toList()))
+        .build();
+    // the receive operator under test runs in stage 0; sender metadata is at stage/index 1
+    _stageMetadataListBoth = Arrays.asList(null, stageMetadata);
+    stageMetadata = new StageMetadata.Builder()
+        .setWorkerMetadataList(Stream.of(server1).map(
+            s -> new WorkerMetadata.Builder().setVirtualServerAddress(s).build()).collect(Collectors.toList()))
+        .build();
+    _stageMetadataList1 = Arrays.asList(null, stageMetadata);
   }
 
   @AfterMethod
@@ -93,13 +101,15 @@ public class SortedMailboxReceiveOperatorTest {
 
   @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = "Failed to find instance.*")
   public void shouldThrowSingletonNoMatchMailboxServer() {
-    when(_server1.getQueryMailboxPort()).thenReturn(456);
-    when(_server2.getQueryMailboxPort()).thenReturn(789);
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Arrays.asList(_server1, _server2));
+    VirtualServerAddress server1 = new VirtualServerAddress("localhost", 456, 0);
+    VirtualServerAddress server2 = new VirtualServerAddress("localhost", 789, 1);
+    StageMetadata stageMetadata = new StageMetadata.Builder()
+        .setWorkerMetadataList(Stream.of(server1, server2).map(
+            s -> new WorkerMetadata.Builder().setVirtualServerAddress(s).build()).collect(Collectors.toList()))
+        .build();
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            Arrays.asList(null, stageMetadata), false);
     //noinspection resource
     new SortedMailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, DATA_SCHEMA, COLLATION_KEYS,
         COLLATION_DIRECTIONS, false, 1);
@@ -107,11 +117,9 @@ public class SortedMailboxReceiveOperatorTest {
 
   @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = "Multiple instances.*")
   public void shouldThrowReceiveSingletonFromMultiMatchMailboxServer() {
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Arrays.asList(_server1, _server2));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataListBoth, false);
     //noinspection resource
     new SortedMailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, DATA_SCHEMA, COLLATION_KEYS,
         COLLATION_DIRECTIONS, false, 1);
@@ -121,7 +129,7 @@ public class SortedMailboxReceiveOperatorTest {
   public void shouldThrowRangeDistributionNotSupported() {
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.emptyMap(), false);
+            Collections.emptyList(), false);
     //noinspection resource
     new SortedMailboxReceiveOperator(context, RelDistribution.Type.RANGE_DISTRIBUTED, DATA_SCHEMA, COLLATION_KEYS,
         COLLATION_DIRECTIONS, false, 1);
@@ -130,12 +138,9 @@ public class SortedMailboxReceiveOperatorTest {
   @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = "Collation keys.*")
   public void shouldThrowOnEmptyCollationKey() {
     when(_mailboxService.getReceivingMailbox(MAILBOX_ID_1)).thenReturn(_mailbox1);
-
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Collections.singletonList(_server1));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, 10L, System.currentTimeMillis() + 10L,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataList1, false);
     //noinspection resource
     new SortedMailboxReceiveOperator(context, RelDistribution.Type.SINGLETON, DATA_SCHEMA, Collections.emptyList(),
         Collections.emptyList(), false, 1);
@@ -145,13 +150,10 @@ public class SortedMailboxReceiveOperatorTest {
   public void shouldTimeoutOnExtraLongSleep()
       throws InterruptedException {
     when(_mailboxService.getReceivingMailbox(MAILBOX_ID_1)).thenReturn(_mailbox1);
-
     // Short timeoutMs should result in timeout
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Collections.singletonList(_server1));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, 10L, System.currentTimeMillis() + 10L,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataList1, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.SINGLETON, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       Thread.sleep(100L);
@@ -163,7 +165,7 @@ public class SortedMailboxReceiveOperatorTest {
 
     // Longer timeout or default timeout (10s) doesn't result in timeout
     context = new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, 10_000L,
-        System.currentTimeMillis() + 10_000L, Collections.singletonMap(1, stageMetadata), false);
+        System.currentTimeMillis() + 10_000L, _stageMetadataList1, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.SINGLETON, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       Thread.sleep(100L);
@@ -175,12 +177,9 @@ public class SortedMailboxReceiveOperatorTest {
   @Test
   public void shouldReceiveSingletonNullMailbox() {
     when(_mailboxService.getReceivingMailbox(MAILBOX_ID_1)).thenReturn(_mailbox1);
-
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Collections.singletonList(_server1));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataList1, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.SINGLETON, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       assertTrue(receiveOp.nextBlock().isNoOpBlock());
@@ -191,12 +190,9 @@ public class SortedMailboxReceiveOperatorTest {
   public void shouldReceiveEosDirectlyFromSender() {
     when(_mailboxService.getReceivingMailbox(MAILBOX_ID_1)).thenReturn(_mailbox1);
     when(_mailbox1.poll()).thenReturn(TransferableBlockUtils.getEndOfStreamTransferableBlock());
-
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Collections.singletonList(_server1));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataList1, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.SINGLETON, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       assertTrue(receiveOp.nextBlock().isEndOfStreamBlock());
@@ -209,12 +205,9 @@ public class SortedMailboxReceiveOperatorTest {
     Object[] row = new Object[]{1, 1};
     when(_mailbox1.poll()).thenReturn(OperatorTestUtil.block(DATA_SCHEMA, row),
         TransferableBlockUtils.getEndOfStreamTransferableBlock());
-
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Collections.singletonList(_server1));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataList1, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.SINGLETON, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       List<Object[]> actualRows = receiveOp.nextBlock().getContainer();
@@ -230,12 +223,9 @@ public class SortedMailboxReceiveOperatorTest {
     String errorMessage = "TEST ERROR";
     when(_mailbox1.poll()).thenReturn(
         TransferableBlockUtils.getErrorTransferableBlock(new RuntimeException(errorMessage)));
-
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Collections.singletonList(_server1));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataList1, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.SINGLETON, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       TransferableBlock block = receiveOp.nextBlock();
@@ -252,12 +242,9 @@ public class SortedMailboxReceiveOperatorTest {
     Object[] row = new Object[]{1, 1};
     when(_mailbox2.poll()).thenReturn(OperatorTestUtil.block(DATA_SCHEMA, row),
         TransferableBlockUtils.getEndOfStreamTransferableBlock());
-
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Arrays.asList(_server1, _server2));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataListBoth, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.HASH_DISTRIBUTED, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       assertTrue(receiveOp.nextBlock().isNoOpBlock());
@@ -278,12 +265,9 @@ public class SortedMailboxReceiveOperatorTest {
     Object[] row = new Object[]{3, 3};
     when(_mailbox2.poll()).thenReturn(OperatorTestUtil.block(DATA_SCHEMA, row),
         TransferableBlockUtils.getEndOfStreamTransferableBlock());
-
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Arrays.asList(_server1, _server2));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataListBoth, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.HASH_DISTRIBUTED, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       TransferableBlock block = receiveOp.nextBlock();
@@ -306,12 +290,9 @@ public class SortedMailboxReceiveOperatorTest {
     when(_mailbox2.poll()).thenReturn(OperatorTestUtil.block(DATA_SCHEMA, row3),
         OperatorTestUtil.block(DATA_SCHEMA, row4), OperatorTestUtil.block(DATA_SCHEMA, row5),
         TransferableBlockUtils.getEndOfStreamTransferableBlock());
-
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Arrays.asList(_server1, _server2));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataListBoth, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.HASH_DISTRIBUTED, DATA_SCHEMA, COLLATION_KEYS, COLLATION_DIRECTIONS, false, 1)) {
       assertEquals(receiveOp.nextBlock().getContainer(), Arrays.asList(row5, row2, row4, row1, row3));
@@ -340,11 +321,9 @@ public class SortedMailboxReceiveOperatorTest {
         OperatorTestUtil.block(dataSchema, row4), OperatorTestUtil.block(dataSchema, row5),
         TransferableBlockUtils.getEndOfStreamTransferableBlock());
 
-    StageMetadata stageMetadata = new StageMetadata();
-    stageMetadata.setServerInstances(Arrays.asList(_server1, _server2));
     OpChainExecutionContext context =
         new OpChainExecutionContext(_mailboxService, 0, 0, RECEIVER_ADDRESS, Long.MAX_VALUE, Long.MAX_VALUE,
-            Collections.singletonMap(1, stageMetadata), false);
+            _stageMetadataListBoth, false);
     try (SortedMailboxReceiveOperator receiveOp = new SortedMailboxReceiveOperator(context,
         RelDistribution.Type.HASH_DISTRIBUTED, dataSchema, collationKeys, collationDirection, false, 1)) {
       assertEquals(receiveOp.nextBlock().getContainer(), Arrays.asList(row1, row2, row3, row5, row4));
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtilsTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtilsTest.java
index 62cd9b6e18..9ca24ebf48 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtilsTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtilsTest.java
@@ -19,48 +19,40 @@
 
 package org.apache.pinot.query.runtime.plan.serde;
 
-import org.apache.pinot.query.routing.VirtualServer;
+import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.mockito.Mockito;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
-import static org.testng.Assert.*;
-
 
 public class QueryPlanSerDeUtilsTest {
 
   @Test
   public void shouldSerializeServer() {
     // Given:
-    VirtualServer server = Mockito.mock(VirtualServer.class);
-    Mockito.when(server.getVirtualId()).thenReturn(1);
-    Mockito.when(server.getHostname()).thenReturn("Server_192.987.1.123");
-    Mockito.when(server.getPort()).thenReturn(80);
-    Mockito.when(server.getGrpcPort()).thenReturn(10);
-    Mockito.when(server.getQueryServicePort()).thenReturn(20);
-    Mockito.when(server.getQueryMailboxPort()).thenReturn(30);
+    VirtualServerAddress server = Mockito.mock(VirtualServerAddress.class);
+    Mockito.when(server.workerId()).thenReturn(1);
+    Mockito.when(server.hostname()).thenReturn("Server_192.987.1.123");
+    Mockito.when(server.port()).thenReturn(80);
 
     // When:
-    String serialized = QueryPlanSerDeUtils.instanceToString(server);
+    String serialized = QueryPlanSerDeUtils.addressToProto(server);
 
     // Then:
-    Assert.assertEquals(serialized, "1@Server_192.987.1.123:80(10:20:30)");
+    Assert.assertEquals(serialized, "1@Server_192.987.1.123:80");
   }
 
   @Test
   public void shouldDeserializeServerString() {
     // Given:
-    String serverString = "1@Server_192.987.1.123:80(10:20:30)";
+    String serverString = "1@Server_192.987.1.123:80";
 
     // When:
-    VirtualServer server = QueryPlanSerDeUtils.stringToInstance(serverString);
+    VirtualServerAddress server = QueryPlanSerDeUtils.protoToAddress(serverString);
 
     // Then:
-    Assert.assertEquals(server.getVirtualId(), 1);
-    Assert.assertEquals(server.getHostname(), "Server_192.987.1.123");
-    Assert.assertEquals(server.getPort(), 80);
-    Assert.assertEquals(server.getGrpcPort(), 10);
-    Assert.assertEquals(server.getQueryServicePort(), 20);
-    Assert.assertEquals(server.getQueryMailboxPort(), 30);
+    Assert.assertEquals(server.workerId(), 1);
+    Assert.assertEquals(server.hostname(), "Server_192.987.1.123");
+    Assert.assertEquals(server.port(), 80);
   }
 }
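The serde test above pins down the new compact wire format for a virtual server address, "workerId@hostname:port", replacing the old "id@host:port(grpc:service:mailbox)" form. A small round-trip sketch, assuming only the helpers exercised by the test (addressToProto returning the string form, protoToAddress parsing it) and that QueryPlanSerDeUtils lives in the same package as its test:

    import org.apache.pinot.query.routing.VirtualServerAddress;
    import org.apache.pinot.query.runtime.plan.serde.QueryPlanSerDeUtils;

    public class AddressSerDeSketch {
      public static void main(String[] args) {
        // hostname, port, worker id -- the three fields carried by the new format
        VirtualServerAddress address = new VirtualServerAddress("Server_192.987.1.123", 80, 1);
        String serialized = QueryPlanSerDeUtils.addressToProto(address);  // "1@Server_192.987.1.123:80"
        VirtualServerAddress decoded = QueryPlanSerDeUtils.protoToAddress(serialized);
        // decoded.workerId() == 1, decoded.hostname().equals("Server_192.987.1.123"), decoded.port() == 80
      }
    }
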
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/ResourceBasedQueriesTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/ResourceBasedQueriesTest.java
index 90ceeb2b72..686d644445 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/ResourceBasedQueriesTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/ResourceBasedQueriesTest.java
@@ -45,7 +45,7 @@ import org.apache.pinot.core.query.reduce.ExecutionStatsAggregator;
 import org.apache.pinot.query.QueryEnvironmentTestBase;
 import org.apache.pinot.query.QueryServerEnclosure;
 import org.apache.pinot.query.mailbox.MailboxService;
-import org.apache.pinot.query.routing.WorkerInstance;
+import org.apache.pinot.query.routing.QueryServerInstance;
 import org.apache.pinot.query.runtime.QueryRunnerTestBase;
 import org.apache.pinot.query.service.QueryConfig;
 import org.apache.pinot.query.testutils.MockInstanceDataManagerFactory;
@@ -191,8 +191,8 @@ public class ResourceBasedQueriesTest extends QueryRunnerTestBase {
     // this is only used for test identification purposes.
     int port1 = server1.getPort();
     int port2 = server2.getPort();
-    _servers.put(new WorkerInstance("localhost", port1, port1, port1, port1), server1);
-    _servers.put(new WorkerInstance("localhost", port2, port2, port2, port2), server2);
+    _servers.put(new QueryServerInstance("localhost", port1, port1), server1);
+    _servers.put(new QueryServerInstance("localhost", port2, port2), server2);
   }
 
   @AfterClass
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryServerTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryServerTest.java
index cb2199177c..fe0f54667b 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryServerTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryServerTest.java
@@ -32,20 +32,22 @@ import org.apache.pinot.common.proto.PinotQueryWorkerGrpc;
 import org.apache.pinot.common.proto.Worker;
 import org.apache.pinot.common.utils.NamedThreadFactory;
 import org.apache.pinot.core.query.scheduler.resources.ResourceManager;
-import org.apache.pinot.core.transport.ServerInstance;
+import org.apache.pinot.core.routing.TimeBoundaryInfo;
 import org.apache.pinot.query.QueryEnvironment;
 import org.apache.pinot.query.QueryEnvironmentTestBase;
 import org.apache.pinot.query.QueryTestSet;
 import org.apache.pinot.query.planner.QueryPlan;
-import org.apache.pinot.query.planner.StageMetadata;
 import org.apache.pinot.query.planner.stage.StageNode;
-import org.apache.pinot.query.routing.VirtualServer;
-import org.apache.pinot.query.routing.WorkerInstance;
+import org.apache.pinot.query.routing.QueryServerInstance;
+import org.apache.pinot.query.routing.StageMetadata;
+import org.apache.pinot.query.routing.VirtualServerAddress;
+import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.QueryRunner;
 import org.apache.pinot.query.runtime.plan.serde.QueryPlanSerDeUtils;
 import org.apache.pinot.query.service.dispatch.QueryDispatcher;
 import org.apache.pinot.query.testutils.QueryTestUtils;
 import org.apache.pinot.spi.utils.CommonConstants;
+import org.apache.pinot.spi.utils.EqualityUtils;
 import org.apache.pinot.util.TestUtils;
 import org.mockito.Mockito;
 import org.testng.Assert;
@@ -65,7 +67,6 @@ public class QueryServerTest extends QueryTestSet {
       ResourceManager.DEFAULT_QUERY_RUNNER_THREADS, new NamedThreadFactory("QueryServerTest_Runner"));
 
   private final Map<Integer, QueryServer> _queryServerMap = new HashMap<>();
-  private final Map<Integer, ServerInstance> _queryServerInstanceMap = new HashMap<>();
   private final Map<Integer, QueryRunner> _queryRunnerMap = new HashMap<>();
 
   private QueryEnvironment _queryEnvironment;
@@ -83,10 +84,6 @@ public class QueryServerTest extends QueryTestSet {
       queryServer.start();
       _queryServerMap.put(availablePort, queryServer);
       _queryRunnerMap.put(availablePort, queryRunner);
-      // this only test the QueryServer functionality so the server port can be the same as the mailbox port.
-      // this is only use for test identifier purpose.
-      _queryServerInstanceMap.put(availablePort, new WorkerInstance("localhost", availablePort, availablePort,
-          availablePort, availablePort));
     }
 
     List<Integer> portList = Lists.newArrayList(_queryServerMap.keySet());
@@ -110,7 +107,7 @@ public class QueryServerTest extends QueryTestSet {
       throws Exception {
     QueryPlan queryPlan = _queryEnvironment.planQuery(sql);
 
-    for (int stageId : queryPlan.getStageMetadataMap().keySet()) {
+    for (int stageId : queryPlan.getDispatchablePlanMetadataMap().keySet()) {
       if (stageId > 0) { // we do not test reduce stage.
         // only get one worker request out.
         Worker.QueryRequest queryRequest = getQueryRequest(queryPlan, stageId);
@@ -118,7 +115,7 @@ public class QueryServerTest extends QueryTestSet {
         // submit the request for testing.
         submitRequest(queryRequest);
 
-        StageMetadata stageMetadata = queryPlan.getStageMetadataMap().get(stageId);
+        List<StageMetadata> stageMetadataList = queryPlan.getStageMetadataList();
 
         // ensure mock query runner received correctly deserialized payload.
         QueryRunner mockRunner = _queryRunnerMap.get(
@@ -131,7 +128,7 @@ public class QueryServerTest extends QueryTestSet {
             Mockito.verify(mockRunner).processQuery(Mockito.argThat(distributedStagePlan -> {
               StageNode stageNode = queryPlan.getQueryStageMap().get(stageId);
               return isStageNodesEqual(stageNode, distributedStagePlan.getStageRoot())
-                  && isMetadataMapsEqual(stageMetadata, distributedStagePlan.getMetadataMap().get(stageId));
+                  && isMetadataMapsEqual(stageId, stageMetadataList, distributedStagePlan.getStageMetadataList());
             }), Mockito.argThat(requestMetadataMap ->
                 requestIdStr.equals(requestMetadataMap.get(QueryConfig.KEY_OF_BROKER_REQUEST_ID))));
             return true;
@@ -139,14 +136,56 @@ public class QueryServerTest extends QueryTestSet {
             return false;
           }
         }, 10000L, "Error verifying mock QueryRunner intercepted query payload!");
+
+        // reset the mock runner.
+        Mockito.reset(mockRunner);
       }
     }
   }
 
-  private static boolean isMetadataMapsEqual(StageMetadata left, StageMetadata right) {
-    return left.getServerInstances().equals(right.getServerInstances())
-        && left.getServerInstanceToSegmentsMap().equals(right.getServerInstanceToSegmentsMap())
-        && left.getScannedTables().equals(right.getScannedTables());
+  private boolean isMetadataMapsEqual(int stageId, List<StageMetadata> expectedStageMetadataList,
+      List<StageMetadata> deserializedStageMetadataList) {
+    StageMetadata expected = expectedStageMetadataList.get(stageId);
+    StageMetadata actual = deserializedStageMetadataList.get(stageId);
+    return isStageMetadataEqual(expected, actual);
+  }
+
+  private boolean isStageMetadataEqual(StageMetadata expected, StageMetadata actual) {
+    if (!EqualityUtils.isEqual(StageMetadata.getTableName(expected), StageMetadata.getTableName(actual))) {
+      return false;
+    }
+    TimeBoundaryInfo expectedTimeBoundaryInfo = StageMetadata.getTimeBoundary(expected);
+    TimeBoundaryInfo actualTimeBoundaryInfo = StageMetadata.getTimeBoundary(actual);
+    if (expectedTimeBoundaryInfo == null && actualTimeBoundaryInfo != null
+        || expectedTimeBoundaryInfo != null && actualTimeBoundaryInfo == null) {
+      return false;
+    }
+    if (expectedTimeBoundaryInfo != null && actualTimeBoundaryInfo != null
+        && (!EqualityUtils.isEqual(expectedTimeBoundaryInfo.getTimeColumn(), actualTimeBoundaryInfo.getTimeColumn())
+        || !EqualityUtils.isEqual(expectedTimeBoundaryInfo.getTimeValue(), actualTimeBoundaryInfo.getTimeValue()))) {
+      return false;
+    }
+    List<WorkerMetadata> expectedWorkerMetadataList = expected.getWorkerMetadataList();
+    List<WorkerMetadata> actualWorkerMetadataList = actual.getWorkerMetadataList();
+    if (expectedWorkerMetadataList.size() != actualWorkerMetadataList.size()) {
+      return false;
+    }
+    for (int i = 0; i < expectedWorkerMetadataList.size(); i++) {
+      if (!isWorkerMetadataEqual(expectedWorkerMetadataList.get(i), actualWorkerMetadataList.get(i))) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  private static boolean isWorkerMetadataEqual(WorkerMetadata expected, WorkerMetadata actual) {
+    if (!expected.getVirtualServerAddress().hostname().equals(actual.getVirtualServerAddress().hostname())
+        || expected.getVirtualServerAddress().port() != actual.getVirtualServerAddress().port()
+        || expected.getVirtualServerAddress().workerId() != actual.getVirtualServerAddress().workerId()) {
+      return false;
+    }
+    return EqualityUtils.isEqual(WorkerMetadata.getTableSegmentsMap(expected),
+        WorkerMetadata.getTableSegmentsMap(actual));
   }
 
   private static boolean isStageNodesEqual(StageNode left, StageNode right) {
@@ -178,16 +217,22 @@ public class QueryServerTest extends QueryTestSet {
   }
 
   private Worker.QueryRequest getQueryRequest(QueryPlan queryPlan, int stageId) {
-    VirtualServer serverInstance = queryPlan.getStageMetadataMap().get(stageId).getServerInstances().get(0);
+    Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap =
+        queryPlan.getDispatchablePlanMetadataMap().get(stageId).getServerInstanceToWorkerIdMap();
+    // this particular test set assumes each request is dispatched to a single QueryServerInstance,
+    // as it does not exercise multi-tenancy dispatch (that is covered in QueryDispatcherTest)
+    QueryServerInstance serverInstance = serverInstanceToWorkerIdMap.keySet().iterator().next();
+    int workerId = serverInstanceToWorkerIdMap.get(serverInstance).get(0);
 
     return Worker.QueryRequest.newBuilder().setStagePlan(QueryPlanSerDeUtils.serialize(
-            QueryDispatcher.constructDistributedStagePlan(queryPlan, stageId, serverInstance)))
+        QueryDispatcher.constructDistributedStagePlan(queryPlan, stageId,
+            new VirtualServerAddress(serverInstance, workerId))))
         // the default configurations that must exist.
         .putMetadata(QueryConfig.KEY_OF_BROKER_REQUEST_ID, String.valueOf(RANDOM_REQUEST_ID_GEN.nextLong()))
         .putMetadata(QueryConfig.KEY_OF_BROKER_REQUEST_TIMEOUT_MS,
             String.valueOf(CommonConstants.Broker.DEFAULT_BROKER_TIMEOUT_MS))
         // extra configurations we want to test also parsed out correctly.
         .putMetadata(KEY_OF_SERVER_INSTANCE_HOST, serverInstance.getHostname())
-        .putMetadata(KEY_OF_SERVER_INSTANCE_PORT, String.valueOf(serverInstance.getPort())).build();
+        .putMetadata(KEY_OF_SERVER_INSTANCE_PORT, String.valueOf(serverInstance.getQueryServicePort())).build();
   }
 }
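The getQueryRequest change above illustrates the new dispatch lookup: each stage's DispatchablePlanMetadata maps a QueryServerInstance to the worker ids it hosts, and a per-worker VirtualServerAddress is built from the chosen server instance plus one of those worker ids. A condensed sketch of that lookup, assuming only the accessors visible in the diff:

    import java.util.List;
    import java.util.Map;
    import org.apache.pinot.query.planner.QueryPlan;
    import org.apache.pinot.query.routing.QueryServerInstance;
    import org.apache.pinot.query.routing.VirtualServerAddress;

    public class DispatchTargetSketch {
      public static VirtualServerAddress pickFirstWorker(QueryPlan queryPlan, int stageId) {
        Map<QueryServerInstance, List<Integer>> serverToWorkerIds =
            queryPlan.getDispatchablePlanMetadataMap().get(stageId).getServerInstanceToWorkerIdMap();
        // take any server assigned to this stage and its first worker id
        QueryServerInstance server = serverToWorkerIds.keySet().iterator().next();
        int workerId = serverToWorkerIds.get(server).get(0);
        // VirtualServerAddress can now be derived directly from a QueryServerInstance plus a worker id
        return new VirtualServerAddress(server, workerId);
      }
    }
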

