You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by yu...@apache.org on 2022/11/29 16:55:31 UTC

[pinot] branch master updated: [multistage] Add Multi Stage Strict Replica Group Routing Strategy (#9808)

This is an automated email from the ASF dual-hosted git repository.

yupeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new e41bdd0c6f [multistage] Add Multi Stage Strict Replica Group Routing Strategy (#9808)
e41bdd0c6f is described below

commit e41bdd0c6ff851ca65102fb7bdc83161753e839c
Author: Ankit Sultana <an...@uber.com>
AuthorDate: Tue Nov 29 22:25:23 2022 +0530

    [multistage] Add Multi Stage Strict Replica Group Routing Strategy (#9808)
    
    * Rebase on master
    
    * Cleanup
    
    * Add tests
    
    * Fix test
    
    * Address feedback
    
    * Mark feature is in Beta
    
    * Address feedback
    
    * Address feedback
    
    * Rename to MultiStageReplicaGroup
---
 .../broker/api/resources/PinotBrokerDebug.java     |  15 +-
 .../requesthandler/BaseBrokerRequestHandler.java   |   4 +-
 .../MultiStageBrokerRequestHandler.java            |   2 +-
 .../pinot/broker/routing/BrokerRoutingManager.java |  13 +-
 .../instanceselector/BaseInstanceSelector.java     |  10 +-
 .../routing/instanceselector/InstanceSelector.java |   7 +-
 .../instanceselector/InstanceSelectorFactory.java  |  11 +-
 .../MultiStageReplicaGroupSelector.java            | 150 +++++++++++
 .../broker/broker/HelixBrokerStarterTest.java      |   7 +-
 .../BaseBrokerRequestHandlerTest.java              |   2 +-
 .../instanceselector/InstanceSelectorTest.java     | 277 +++++++++++++++------
 .../apache/pinot/core/routing/RoutingManager.java  |   2 +-
 .../org/apache/pinot/query/QueryEnvironment.java   |  16 +-
 .../pinot/query/planner/logical/StagePlanner.java  |   6 +-
 .../apache/pinot/query/routing/WorkerManager.java  |  16 +-
 .../query/testutils/MockRoutingManagerFactory.java |   2 +-
 .../pinot/spi/config/table/RoutingConfig.java      |   1 +
 17 files changed, 426 insertions(+), 115 deletions(-)

diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/api/resources/PinotBrokerDebug.java b/pinot-broker/src/main/java/org/apache/pinot/broker/api/resources/PinotBrokerDebug.java
index daa17a0e5c..f9eb55b535 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/api/resources/PinotBrokerDebug.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/api/resources/PinotBrokerDebug.java
@@ -30,6 +30,7 @@ import io.swagger.annotations.SwaggerDefinition;
 import java.util.List;
 import java.util.Map;
 import java.util.TreeMap;
+import java.util.concurrent.atomic.AtomicLong;
 import javax.inject.Inject;
 import javax.ws.rs.GET;
 import javax.ws.rs.Path;
@@ -59,6 +60,9 @@ import static org.apache.pinot.spi.utils.CommonConstants.SWAGGER_AUTHORIZATION_K
 // TODO: Add APIs to return the RoutingTable (with unavailable segments)
 public class PinotBrokerDebug {
 
+  // Request ID is passed to the RoutingManager to rotate the selected replica-group.
+  private final AtomicLong _requestIdGenerator = new AtomicLong();
+
   @Inject
   private BrokerRoutingManager _routingManager;
 
@@ -102,7 +106,7 @@ public class PinotBrokerDebug {
     if (tableType != TableType.REALTIME) {
       String offlineTableName = TableNameBuilder.OFFLINE.tableNameWithType(tableName);
       RoutingTable routingTable = _routingManager.getRoutingTable(
-          CalciteSqlCompiler.compileToBrokerRequest("SELECT * FROM " + offlineTableName));
+          CalciteSqlCompiler.compileToBrokerRequest("SELECT * FROM " + offlineTableName), getRequestId());
       if (routingTable != null) {
         result.put(offlineTableName, routingTable.getServerInstanceToSegmentsMap());
       }
@@ -110,7 +114,7 @@ public class PinotBrokerDebug {
     if (tableType != TableType.OFFLINE) {
       String realtimeTableName = TableNameBuilder.REALTIME.tableNameWithType(tableName);
       RoutingTable routingTable = _routingManager.getRoutingTable(
-          CalciteSqlCompiler.compileToBrokerRequest("SELECT * FROM " + realtimeTableName));
+          CalciteSqlCompiler.compileToBrokerRequest("SELECT * FROM " + realtimeTableName), getRequestId());
       if (routingTable != null) {
         result.put(realtimeTableName, routingTable.getServerInstanceToSegmentsMap());
       }
@@ -133,7 +137,8 @@ public class PinotBrokerDebug {
   })
   public Map<ServerInstance, List<String>> getRoutingTableForQuery(
       @ApiParam(value = "SQL query (table name should have type suffix)") @QueryParam("query") String query) {
-    RoutingTable routingTable = _routingManager.getRoutingTable(CalciteSqlCompiler.compileToBrokerRequest(query));
+    RoutingTable routingTable = _routingManager.getRoutingTable(CalciteSqlCompiler.compileToBrokerRequest(query),
+        getRequestId());
     if (routingTable != null) {
       return routingTable.getServerInstanceToSegmentsMap();
     } else {
@@ -157,4 +162,8 @@ public class PinotBrokerDebug {
   public String getServerRoutingStats() {
     return _serverRoutingStatsManager.getServerRoutingStatsStr();
   }
+
+  private long getRequestId() {
+    return _requestIdGenerator.getAndIncrement();
+  }
 }
diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java
index ade86b61bd..ae88689a51 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java
@@ -551,7 +551,7 @@ public abstract class BaseBrokerRequestHandler implements BrokerRequestHandler {
     int numPrunedSegmentsTotal = 0;
     if (offlineBrokerRequest != null) {
       // NOTE: Routing table might be null if table is just removed
-      RoutingTable routingTable = _routingManager.getRoutingTable(offlineBrokerRequest);
+      RoutingTable routingTable = _routingManager.getRoutingTable(offlineBrokerRequest, requestId);
       if (routingTable != null) {
         unavailableSegments.addAll(routingTable.getUnavailableSegments());
         Map<ServerInstance, List<String>> serverInstanceToSegmentsMap = routingTable.getServerInstanceToSegmentsMap();
@@ -567,7 +567,7 @@ public abstract class BaseBrokerRequestHandler implements BrokerRequestHandler {
     }
     if (realtimeBrokerRequest != null) {
       // NOTE: Routing table might be null if table is just removed
-      RoutingTable routingTable = _routingManager.getRoutingTable(realtimeBrokerRequest);
+      RoutingTable routingTable = _routingManager.getRoutingTable(realtimeBrokerRequest, requestId);
       if (routingTable != null) {
         unavailableSegments.addAll(routingTable.getUnavailableSegments());
         Map<ServerInstance, List<String>> serverInstanceToSegmentsMap = routingTable.getServerInstanceToSegmentsMap();
diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
index 0b1a67efbb..75e8d8d3f0 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
@@ -144,7 +144,7 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler {
           return constructMultistageExplainPlan(query, plan);
         case SELECT:
         default:
-          queryPlan = _queryEnvironment.planQuery(query, sqlNodeAndOptions);
+          queryPlan = _queryEnvironment.planQuery(query, sqlNodeAndOptions, requestId);
           break;
       }
     } catch (Exception e) {
diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/BrokerRoutingManager.java b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/BrokerRoutingManager.java
index ee0982606d..883e9cfb02 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/BrokerRoutingManager.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/BrokerRoutingManager.java
@@ -81,7 +81,7 @@ import org.slf4j.LoggerFactory;
  *   <li>{@link #removeRouting(String)}: Removes the routing for a table</li>
  *   <li>{@link #refreshSegment(String, String)}: Refreshes the metadata for a segment</li>
  *   <li>{@link #routingExists(String)}: Returns whether the routing exists for a table</li>
- *   <li>{@link #getRoutingTable(BrokerRequest)}: Returns the routing table for a query</li>
+ *   <li>{@link #getRoutingTable(BrokerRequest, long)}: Returns the routing table for a query</li>
  *   <li>{@link #getTimeBoundaryInfo(String)}: Returns the time boundary info for a table</li>
  *   <li>{@link #getQueryTimeoutMs(String)}: Returns the table-level query timeout in milliseconds for a table</li>
  * </ul>
@@ -437,7 +437,8 @@ public class BrokerRoutingManager implements RoutingManager, ClusterChangeHandle
     AdaptiveServerSelector adaptiveServerSelector =
         AdaptiveServerSelectorFactory.getAdaptiveServerSelector(_serverRoutingStatsManager, _pinotConfig);
     InstanceSelector instanceSelector =
-        InstanceSelectorFactory.getInstanceSelector(tableConfig, _brokerMetrics, adaptiveServerSelector);
+        InstanceSelectorFactory.getInstanceSelector(tableConfig, _propertyStore, _brokerMetrics,
+            adaptiveServerSelector);
     instanceSelector.init(_routableServers, idealState, externalView, preSelectedOnlineSegments);
 
     // Add time boundary manager if both offline and real-time part exist for a hybrid table
@@ -567,13 +568,13 @@ public class BrokerRoutingManager implements RoutingManager, ClusterChangeHandle
    */
   @Nullable
   @Override
-  public RoutingTable getRoutingTable(BrokerRequest brokerRequest) {
+  public RoutingTable getRoutingTable(BrokerRequest brokerRequest, long requestId) {
     String tableNameWithType = brokerRequest.getQuerySource().getTableName();
     RoutingEntry routingEntry = _routingEntryMap.get(tableNameWithType);
     if (routingEntry == null) {
       return null;
     }
-    InstanceSelector.SelectionResult selectionResult = routingEntry.calculateRouting(brokerRequest);
+    InstanceSelector.SelectionResult selectionResult = routingEntry.calculateRouting(brokerRequest, requestId);
     Map<String, String> segmentToInstanceMap = selectionResult.getSegmentToInstanceMap();
     Map<ServerInstance, List<String>> serverInstanceToSegmentsMap = new HashMap<>();
     for (Map.Entry<String, String> entry : segmentToInstanceMap.entrySet()) {
@@ -717,7 +718,7 @@ public class BrokerRoutingManager implements RoutingManager, ClusterChangeHandle
       }
     }
 
-    InstanceSelector.SelectionResult calculateRouting(BrokerRequest brokerRequest) {
+    InstanceSelector.SelectionResult calculateRouting(BrokerRequest brokerRequest, long requestId) {
       Set<String> selectedSegments = _segmentSelector.select(brokerRequest);
       int numTotalSelectedSegments = selectedSegments.size();
       if (!selectedSegments.isEmpty()) {
@@ -728,7 +729,7 @@ public class BrokerRoutingManager implements RoutingManager, ClusterChangeHandle
       int numPrunedSegments = numTotalSelectedSegments - selectedSegments.size();
       if (!selectedSegments.isEmpty()) {
         InstanceSelector.SelectionResult selectionResult = _instanceSelector.select(brokerRequest,
-            new ArrayList<>(selectedSegments));
+            new ArrayList<>(selectedSegments), requestId);
         selectionResult.setNumPrunedSegments(numPrunedSegments);
         return selectionResult;
       } else {
diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/instanceselector/BaseInstanceSelector.java b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/instanceselector/BaseInstanceSelector.java
index 9b92d6031c..9325035aca 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/instanceselector/BaseInstanceSelector.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/instanceselector/BaseInstanceSelector.java
@@ -26,7 +26,6 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.SortedMap;
-import java.util.concurrent.atomic.AtomicLong;
 import javax.annotation.Nullable;
 import org.apache.helix.model.ExternalView;
 import org.apache.helix.model.IdealState;
@@ -50,9 +49,8 @@ abstract class BaseInstanceSelector implements InstanceSelector {
   private static final Logger LOGGER = LoggerFactory.getLogger(BaseInstanceSelector.class);
 
   // To prevent int overflow, reset the request id once it reaches this value
-  private static final int MAX_REQUEST_ID = 1_000_000_000;
+  private static final long MAX_REQUEST_ID = 1_000_000_000;
 
-  private final AtomicLong _requestId = new AtomicLong();
   private final String _tableNameWithType;
   private final BrokerMetrics _brokerMetrics;
   protected final AdaptiveServerSelector _adaptiveServerSelector;
@@ -267,13 +265,13 @@ abstract class BaseInstanceSelector implements InstanceSelector {
   }
 
   @Override
-  public SelectionResult select(BrokerRequest brokerRequest, List<String> segments) {
-    int requestId = (int) (_requestId.getAndIncrement() % MAX_REQUEST_ID);
+  public SelectionResult select(BrokerRequest brokerRequest, List<String> segments, long requestId) {
     Map<String, String> queryOptions = (brokerRequest.getPinotQuery() != null
         && brokerRequest.getPinotQuery().getQueryOptions() != null)
         ? brokerRequest.getPinotQuery().getQueryOptions()
         : Collections.emptyMap();
-    Map<String, String> segmentToInstanceMap = select(segments, requestId, _segmentToEnabledInstancesMap,
+    int requestIdInt = (int) (requestId % MAX_REQUEST_ID);
+    Map<String, String> segmentToInstanceMap = select(segments, requestIdInt, _segmentToEnabledInstancesMap,
         queryOptions);
     Set<String> unavailableSegments = _unavailableSegments;
     if (unavailableSegments.isEmpty()) {
diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/instanceselector/InstanceSelector.java b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/instanceselector/InstanceSelector.java
index 4edaf69b0a..4c96007fd6 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/instanceselector/InstanceSelector.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/instanceselector/InstanceSelector.java
@@ -55,8 +55,13 @@ public interface InstanceSelector {
    * Selects the server instances for the given segments queried by the given broker request, returns a map from segment
    * to selected server instance hosting the segment and a set of unavailable segments (no enabled instance or all
    * enabled instances are in ERROR state).
+   *
+   * @param brokerRequest BrokerRequest for the query
+   * @param segments segments for which instance needs to be selected
+   * @param requestId requestId generated by the Broker for a query
+   * @return instance of SelectionResult which describes the instance to pick for a given segment
    */
-  SelectionResult select(BrokerRequest brokerRequest, List<String> segments);
+  SelectionResult select(BrokerRequest brokerRequest, List<String> segments, long requestId);
 
   class SelectionResult {
     private final Map<String, String> _segmentToInstanceMap;
diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/instanceselector/InstanceSelectorFactory.java b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/instanceselector/InstanceSelectorFactory.java
index 0ef46d7b7a..8cc9f260f8 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/instanceselector/InstanceSelectorFactory.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/instanceselector/InstanceSelectorFactory.java
@@ -19,6 +19,8 @@
 package org.apache.pinot.broker.routing.instanceselector;
 
 import javax.annotation.Nullable;
+import org.apache.helix.store.zk.ZkHelixPropertyStore;
+import org.apache.helix.zookeeper.datamodel.ZNRecord;
 import org.apache.pinot.broker.routing.adaptiveserverselector.AdaptiveServerSelector;
 import org.apache.pinot.common.metrics.BrokerMetrics;
 import org.apache.pinot.spi.config.table.RoutingConfig;
@@ -37,7 +39,8 @@ public class InstanceSelectorFactory {
   public static final String LEGACY_REPLICA_GROUP_OFFLINE_ROUTING = "PartitionAwareOffline";
   public static final String LEGACY_REPLICA_GROUP_REALTIME_ROUTING = "PartitionAwareRealtime";
 
-  public static InstanceSelector getInstanceSelector(TableConfig tableConfig, BrokerMetrics brokerMetrics,
+  public static InstanceSelector getInstanceSelector(TableConfig tableConfig,
+      ZkHelixPropertyStore<ZNRecord> propertyStore, BrokerMetrics brokerMetrics,
       @Nullable AdaptiveServerSelector adaptiveServerSelector) {
     String tableNameWithType = tableConfig.getTableName();
     RoutingConfig routingConfig = tableConfig.getRoutingConfig();
@@ -55,6 +58,12 @@ public class InstanceSelectorFactory {
         LOGGER.info("Using StrictReplicaGroupInstanceSelector for table: {}", tableNameWithType);
         return new StrictReplicaGroupInstanceSelector(tableNameWithType, brokerMetrics, adaptiveServerSelector);
       }
+      if (RoutingConfig.MULTI_STAGE_REPLICA_GROUP_SELECTOR_TYPE.equalsIgnoreCase(
+          routingConfig.getInstanceSelectorType())) {
+        LOGGER.info("Using {} for table: {}", routingConfig.getInstanceSelectorType(), tableNameWithType);
+        return new MultiStageReplicaGroupSelector(tableNameWithType, propertyStore, brokerMetrics,
+            adaptiveServerSelector);
+      }
     }
     return new BalancedInstanceSelector(tableNameWithType, brokerMetrics, adaptiveServerSelector);
   }
diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/instanceselector/MultiStageReplicaGroupSelector.java b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/instanceselector/MultiStageReplicaGroupSelector.java
new file mode 100644
index 0000000000..0a6d66510c
--- /dev/null
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/instanceselector/MultiStageReplicaGroupSelector.java
@@ -0,0 +1,150 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.broker.routing.instanceselector;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
+import org.apache.helix.model.ExternalView;
+import org.apache.helix.model.IdealState;
+import org.apache.helix.store.zk.ZkHelixPropertyStore;
+import org.apache.helix.zookeeper.datamodel.ZNRecord;
+import org.apache.pinot.broker.routing.adaptiveserverselector.AdaptiveServerSelector;
+import org.apache.pinot.common.assignment.InstancePartitions;
+import org.apache.pinot.common.assignment.InstancePartitionsUtils;
+import org.apache.pinot.common.metrics.BrokerMetrics;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.config.table.assignment.InstancePartitionsType;
+import org.apache.pinot.spi.utils.builder.TableNameBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Instance selector for multi-stage queries which can ensure that Colocated Tables always leverage Colocated Join
+ * whenever possible. To achieve this, this instance-selector uses InstancePartitions (IP) to determine replica-groups,
+ * as opposed to IdealState used by other instance-selectors. Moreover, this also uses the requestId generated by
+ * Pinot broker to determine the replica-group picked for each table involved in the query, as opposed to using a
+ * member variable. There may be scenarios where an instance in the chosen replica-group is down. In that case, this
+ * strategy will try to pick another replica-group. For realtime tables, this strategy uses only CONSUMING partitions.
+ * This feature is in <strong>Beta</strong>.
+ */
+public class MultiStageReplicaGroupSelector extends BaseInstanceSelector {
+  private static final Logger LOGGER = LoggerFactory.getLogger(MultiStageReplicaGroupSelector.class);
+
+  private final String _tableNameWithType;
+  private final ZkHelixPropertyStore<ZNRecord> _propertyStore;
+  private InstancePartitions _instancePartitions;
+
+  public MultiStageReplicaGroupSelector(String tableNameWithType, ZkHelixPropertyStore<ZNRecord> propertyStore,
+      BrokerMetrics brokerMetrics, @Nullable AdaptiveServerSelector adaptiveServerSelector) {
+    super(tableNameWithType, brokerMetrics, adaptiveServerSelector);
+    _tableNameWithType = tableNameWithType;
+    _propertyStore = propertyStore;
+  }
+
+  @Override
+  public void init(Set<String> enabledInstances, IdealState idealState, ExternalView externalView,
+      Set<String> onlineSegments) {
+    super.init(enabledInstances, idealState, externalView, onlineSegments);
+    _instancePartitions = getInstancePartitions();
+  }
+
+  @Override
+  public void onInstancesChange(Set<String> enabledInstances, List<String> changedInstances) {
+    super.onInstancesChange(enabledInstances, changedInstances);
+    _instancePartitions = getInstancePartitions();
+  }
+
+  @Override
+  public void onAssignmentChange(IdealState idealState, ExternalView externalView, Set<String> onlineSegments) {
+    super.onAssignmentChange(idealState, externalView, onlineSegments);
+    _instancePartitions = getInstancePartitions();
+  }
+
+  @Override
+  Map<String, String> select(List<String> segments, int requestId,
+      Map<String, List<String>> segmentToEnabledInstancesMap, Map<String, String> queryOptions) {
+    // Capture the InstancePartitions reference once to avoid race-condition with event-listeners above.
+    InstancePartitions instancePartitions = _instancePartitions;
+    int replicaGroupSelected = requestId % instancePartitions.getNumReplicaGroups();
+    for (int iteration = 0; iteration < instancePartitions.getNumReplicaGroups(); iteration++) {
+      int replicaGroup = (replicaGroupSelected + iteration) % instancePartitions.getNumReplicaGroups();
+      try {
+        return tryAssigning(segmentToEnabledInstancesMap, instancePartitions, replicaGroup);
+      } catch (Exception e) {
+        LOGGER.warn("Unable to select replica-group {} for table: {}", replicaGroup, _tableNameWithType, e);
+      }
+    }
+    throw new RuntimeException(String.format("Unable to find any replica-group to serve table: %s",
+        _tableNameWithType));
+  }
+
+  /**
+   * Returns a map from the segmentName to the corresponding server in the given replica-group. If any segment has no
+   * enabled instance in that replica-group, we throw an exception.
+   */
+  private Map<String, String> tryAssigning(Map<String, List<String>> segmentToEnabledInstancesMap,
+      InstancePartitions instancePartitions, int replicaId) {
+    Set<String> instanceLookUpSet = new HashSet<>();
+    for (int partition = 0; partition < instancePartitions.getNumPartitions(); partition++) {
+      List<String> instances = instancePartitions.getInstances(partition, replicaId);
+      instanceLookUpSet.addAll(instances);
+    }
+    Map<String, String> result = new HashMap<>();
+    for (Map.Entry<String, List<String>> entry : segmentToEnabledInstancesMap.entrySet()) {
+      String segmentName = entry.getKey();
+      boolean found = false;
+      for (String enabledInstanceForSegment : entry.getValue()) {
+        if (instanceLookUpSet.contains(enabledInstanceForSegment)) {
+          found = true;
+          result.put(segmentName, enabledInstanceForSegment);
+          break;
+        }
+      }
+      if (!found) {
+        throw new RuntimeException(String.format("Unable to find an enabled instance for segment: %s", segmentName));
+      }
+    }
+    return result;
+  }
+
+  @VisibleForTesting
+  protected InstancePartitions getInstancePartitions() {
+    // TODO: Evaluate whether we need to provide support for COMPLETE partitions.
+    TableType tableType = TableNameBuilder.getTableTypeFromTableName(_tableNameWithType);
+    Preconditions.checkNotNull(tableType);
+    InstancePartitions instancePartitions = null;
+    if (tableType.equals(TableType.OFFLINE)) {
+      instancePartitions = InstancePartitionsUtils.fetchInstancePartitions(_propertyStore,
+          InstancePartitionsUtils.getInstancePartitionsName(_tableNameWithType, tableType.name()));
+    } else {
+      instancePartitions = InstancePartitionsUtils.fetchInstancePartitions(_propertyStore,
+          InstancePartitionsUtils.getInstancePartitionsName(_tableNameWithType,
+              InstancePartitionsType.CONSUMING.name()));
+    }
+    Preconditions.checkNotNull(instancePartitions);
+    return instancePartitions;
+  }
+}
diff --git a/pinot-broker/src/test/java/org/apache/pinot/broker/broker/HelixBrokerStarterTest.java b/pinot-broker/src/test/java/org/apache/pinot/broker/broker/HelixBrokerStarterTest.java
index b2abbd59a9..53e75d0557 100644
--- a/pinot-broker/src/test/java/org/apache/pinot/broker/broker/HelixBrokerStarterTest.java
+++ b/pinot-broker/src/test/java/org/apache/pinot/broker/broker/HelixBrokerStarterTest.java
@@ -153,7 +153,7 @@ public class HelixBrokerStarterTest extends ControllerTest {
     assertTrue(routingManager.routingExists(REALTIME_TABLE_NAME));
 
     BrokerRequest brokerRequest = CalciteSqlCompiler.compileToBrokerRequest("SELECT * FROM " + OFFLINE_TABLE_NAME);
-    RoutingTable routingTable = routingManager.getRoutingTable(brokerRequest);
+    RoutingTable routingTable = routingManager.getRoutingTable(brokerRequest, 0);
     assertNotNull(routingTable);
     assertEquals(routingTable.getServerInstanceToSegmentsMap().size(), NUM_SERVERS);
     assertEquals(routingTable.getServerInstanceToSegmentsMap().values().iterator().next().size(), NUM_OFFLINE_SEGMENTS);
@@ -164,8 +164,9 @@ public class HelixBrokerStarterTest extends ControllerTest {
         SegmentMetadataMockUtils.mockSegmentMetadata(RAW_TABLE_NAME), "downloadUrl");
 
     TestUtils.waitForCondition(aVoid ->
-        routingManager.getRoutingTable(brokerRequest).getServerInstanceToSegmentsMap().values().iterator().next().size()
-            == NUM_OFFLINE_SEGMENTS + 1, 30_000L, "Failed to add the new segment into the routing table");
+        routingManager.getRoutingTable(brokerRequest, 0).getServerInstanceToSegmentsMap()
+            .values().iterator().next().size() == NUM_OFFLINE_SEGMENTS + 1, 30_000L, "Failed to add the new segment "
+        + "into the routing table");
 
     // Add a new table with different broker tenant
     String newRawTableName = "newTable";
diff --git a/pinot-broker/src/test/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandlerTest.java b/pinot-broker/src/test/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandlerTest.java
index cff4eff193..88f72600e5 100644
--- a/pinot-broker/src/test/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandlerTest.java
+++ b/pinot-broker/src/test/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandlerTest.java
@@ -199,7 +199,7 @@ public class BaseBrokerRequestHandlerTest {
     RoutingTable rt = mock(RoutingTable.class);
     when(rt.getServerInstanceToSegmentsMap()).thenReturn(Collections
         .singletonMap(new ServerInstance(new InstanceConfig("server01_9000")), Collections.singletonList("segment01")));
-    when(routingManager.getRoutingTable(any())).thenReturn(rt);
+    when(routingManager.getRoutingTable(any(), Mockito.anyLong())).thenReturn(rt);
     QueryQuotaManager queryQuotaManager = mock(QueryQuotaManager.class);
     when(queryQuotaManager.acquire(anyString())).thenReturn(true);
     CountDownLatch latch = new CountDownLatch(1);
diff --git a/pinot-broker/src/test/java/org/apache/pinot/broker/routing/instanceselector/InstanceSelectorTest.java b/pinot-broker/src/test/java/org/apache/pinot/broker/routing/instanceselector/InstanceSelectorTest.java
index 5025782a49..07efbe101a 100644
--- a/pinot-broker/src/test/java/org/apache/pinot/broker/routing/instanceselector/InstanceSelectorTest.java
+++ b/pinot-broker/src/test/java/org/apache/pinot/broker/routing/instanceselector/InstanceSelectorTest.java
@@ -18,6 +18,9 @@
  */
 package org.apache.pinot.broker.routing.instanceselector;
 
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
@@ -28,7 +31,10 @@ import java.util.Set;
 import java.util.TreeMap;
 import org.apache.helix.model.ExternalView;
 import org.apache.helix.model.IdealState;
+import org.apache.helix.store.zk.ZkHelixPropertyStore;
+import org.apache.helix.zookeeper.datamodel.ZNRecord;
 import org.apache.pinot.broker.routing.adaptiveserverselector.AdaptiveServerSelector;
+import org.apache.pinot.common.assignment.InstancePartitions;
 import org.apache.pinot.common.metrics.BrokerMetrics;
 import org.apache.pinot.common.request.BrokerRequest;
 import org.apache.pinot.common.request.PinotQuery;
@@ -41,10 +47,13 @@ import static org.apache.pinot.spi.utils.CommonConstants.Helix.StateModel.Segmen
 import static org.apache.pinot.spi.utils.CommonConstants.Helix.StateModel.SegmentStateModel.ERROR;
 import static org.apache.pinot.spi.utils.CommonConstants.Helix.StateModel.SegmentStateModel.OFFLINE;
 import static org.apache.pinot.spi.utils.CommonConstants.Helix.StateModel.SegmentStateModel.ONLINE;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.when;
 import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.assertTrue;
+import static org.testng.Assert.fail;
 
 
 public class InstanceSelectorTest {
@@ -55,26 +64,27 @@ public class InstanceSelectorTest {
   public void testInstanceSelectorFactory() {
     TableConfig tableConfig = mock(TableConfig.class);
     BrokerMetrics brokerMetrics = mock(BrokerMetrics.class);
+    ZkHelixPropertyStore<ZNRecord> propertyStore = mock(ZkHelixPropertyStore.class);
     AdaptiveServerSelector adaptiveServerSelector = null;
 
     // Routing config is missing
-    assertTrue(InstanceSelectorFactory.getInstanceSelector(tableConfig, brokerMetrics,
+    assertTrue(InstanceSelectorFactory.getInstanceSelector(tableConfig, propertyStore, brokerMetrics,
         adaptiveServerSelector) instanceof BalancedInstanceSelector);
 
     // Instance selector type is not configured
     RoutingConfig routingConfig = mock(RoutingConfig.class);
     when(tableConfig.getRoutingConfig()).thenReturn(routingConfig);
-    assertTrue(InstanceSelectorFactory.getInstanceSelector(tableConfig, brokerMetrics,
+    assertTrue(InstanceSelectorFactory.getInstanceSelector(tableConfig, propertyStore, brokerMetrics,
         adaptiveServerSelector) instanceof BalancedInstanceSelector);
 
     // Replica-group instance selector should be returned
     when(routingConfig.getInstanceSelectorType()).thenReturn(RoutingConfig.REPLICA_GROUP_INSTANCE_SELECTOR_TYPE);
-    assertTrue(InstanceSelectorFactory.getInstanceSelector(tableConfig, brokerMetrics,
+    assertTrue(InstanceSelectorFactory.getInstanceSelector(tableConfig, propertyStore, brokerMetrics,
         adaptiveServerSelector) instanceof ReplicaGroupInstanceSelector);
 
     // Strict replica-group instance selector should be returned
     when(routingConfig.getInstanceSelectorType()).thenReturn(RoutingConfig.STRICT_REPLICA_GROUP_INSTANCE_SELECTOR_TYPE);
-    assertTrue(InstanceSelectorFactory.getInstanceSelector(tableConfig, brokerMetrics,
+    assertTrue(InstanceSelectorFactory.getInstanceSelector(tableConfig, propertyStore, brokerMetrics,
         adaptiveServerSelector) instanceof StrictReplicaGroupInstanceSelector);
 
     // Should be backward-compatible with legacy config
@@ -82,12 +92,12 @@ public class InstanceSelectorTest {
     when(tableConfig.getTableType()).thenReturn(TableType.OFFLINE);
     when(routingConfig.getRoutingTableBuilderName()).thenReturn(
         InstanceSelectorFactory.LEGACY_REPLICA_GROUP_OFFLINE_ROUTING);
-    assertTrue(InstanceSelectorFactory.getInstanceSelector(tableConfig, brokerMetrics,
+    assertTrue(InstanceSelectorFactory.getInstanceSelector(tableConfig, propertyStore, brokerMetrics,
         adaptiveServerSelector) instanceof ReplicaGroupInstanceSelector);
     when(tableConfig.getTableType()).thenReturn(TableType.REALTIME);
     when(routingConfig.getRoutingTableBuilderName()).thenReturn(
         InstanceSelectorFactory.LEGACY_REPLICA_GROUP_REALTIME_ROUTING);
-    assertTrue(InstanceSelectorFactory.getInstanceSelector(tableConfig, brokerMetrics,
+    assertTrue(InstanceSelectorFactory.getInstanceSelector(tableConfig, propertyStore, brokerMetrics,
         adaptiveServerSelector) instanceof ReplicaGroupInstanceSelector);
   }
 
@@ -169,6 +179,8 @@ public class InstanceSelectorTest {
     replicaGroupInstanceSelector.init(enabledInstances, idealState, externalView, onlineSegments);
     strictReplicaGroupInstanceSelector.init(enabledInstances, idealState, externalView, onlineSegments);
 
+    int requestId = 0;
+
     // For the 1st request:
     //   BalancedInstanceSelector:
     //     segment0 -> instance0
@@ -189,7 +201,8 @@ public class InstanceSelectorTest {
     expectedBalancedInstanceSelectorResult.put(segment1, instance2);
     expectedBalancedInstanceSelectorResult.put(segment2, instance1);
     expectedBalancedInstanceSelectorResult.put(segment3, instance3);
-    InstanceSelector.SelectionResult selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+    InstanceSelector.SelectionResult selectionResult = balancedInstanceSelector.select(brokerRequest, segments,
+        requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedBalancedInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
     Map<String, String> expectedReplicaGroupInstanceSelectorResult = new HashMap<>();
@@ -197,10 +210,10 @@ public class InstanceSelectorTest {
     expectedReplicaGroupInstanceSelectorResult.put(segment1, instance0);
     expectedReplicaGroupInstanceSelectorResult.put(segment2, instance1);
     expectedReplicaGroupInstanceSelectorResult.put(segment3, instance1);
-    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
-    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
 
@@ -215,12 +228,13 @@ public class InstanceSelectorTest {
     //     segment1 -> instance2
     //     segment2 -> instance3
     //     segment3 -> instance3
+    requestId++;
     expectedBalancedInstanceSelectorResult = new HashMap<>();
     expectedBalancedInstanceSelectorResult.put(segment0, instance2);
     expectedBalancedInstanceSelectorResult.put(segment1, instance0);
     expectedBalancedInstanceSelectorResult.put(segment2, instance3);
     expectedBalancedInstanceSelectorResult.put(segment3, instance1);
-    selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+    selectionResult = balancedInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedBalancedInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
     expectedReplicaGroupInstanceSelectorResult = new HashMap<>();
@@ -228,10 +242,10 @@ public class InstanceSelectorTest {
     expectedReplicaGroupInstanceSelectorResult.put(segment1, instance2);
     expectedReplicaGroupInstanceSelectorResult.put(segment2, instance3);
     expectedReplicaGroupInstanceSelectorResult.put(segment3, instance3);
-    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
-    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
 
@@ -252,12 +266,13 @@ public class InstanceSelectorTest {
     //     segment1 -> instance2
     //     segment2 -> instance1
     //     segment3 -> instance1
+    requestId++;
     expectedBalancedInstanceSelectorResult = new HashMap<>();
     expectedBalancedInstanceSelectorResult.put(segment0, instance2);
     expectedBalancedInstanceSelectorResult.put(segment1, instance2);
     expectedBalancedInstanceSelectorResult.put(segment2, instance1);
     expectedBalancedInstanceSelectorResult.put(segment3, instance3);
-    selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+    selectionResult = balancedInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedBalancedInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
     expectedReplicaGroupInstanceSelectorResult = new HashMap<>();
@@ -265,10 +280,10 @@ public class InstanceSelectorTest {
     expectedReplicaGroupInstanceSelectorResult.put(segment1, instance2);
     expectedReplicaGroupInstanceSelectorResult.put(segment2, instance1);
     expectedReplicaGroupInstanceSelectorResult.put(segment3, instance1);
-    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
-    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
 
@@ -283,12 +298,13 @@ public class InstanceSelectorTest {
     //     segment1 -> instance2
     //     segment2 -> instance3
     //     segment3 -> instance3
+    requestId++;
     expectedBalancedInstanceSelectorResult = new HashMap<>();
     expectedBalancedInstanceSelectorResult.put(segment0, instance2);
     expectedBalancedInstanceSelectorResult.put(segment1, instance2);
     expectedBalancedInstanceSelectorResult.put(segment2, instance3);
     expectedBalancedInstanceSelectorResult.put(segment3, instance1);
-    selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+    selectionResult = balancedInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedBalancedInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
     expectedReplicaGroupInstanceSelectorResult = new HashMap<>();
@@ -296,10 +312,10 @@ public class InstanceSelectorTest {
     expectedReplicaGroupInstanceSelectorResult.put(segment1, instance2);
     expectedReplicaGroupInstanceSelectorResult.put(segment2, instance3);
     expectedReplicaGroupInstanceSelectorResult.put(segment3, instance3);
-    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
-    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
 
@@ -326,21 +342,22 @@ public class InstanceSelectorTest {
     //     segment2 -> instance1
     //     segment3 -> instance1
     //     segment4 -> null
+    requestId++;
     expectedBalancedInstanceSelectorResult = new HashMap<>();
     expectedBalancedInstanceSelectorResult.put(segment1, instance2);
     expectedBalancedInstanceSelectorResult.put(segment2, instance3);
     expectedBalancedInstanceSelectorResult.put(segment3, instance1);
-    selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+    selectionResult = balancedInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedBalancedInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
     expectedReplicaGroupInstanceSelectorResult = new HashMap<>();
     expectedReplicaGroupInstanceSelectorResult.put(segment1, instance2);
     expectedReplicaGroupInstanceSelectorResult.put(segment2, instance1);
     expectedReplicaGroupInstanceSelectorResult.put(segment3, instance1);
-    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
-    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
 
@@ -355,21 +372,22 @@ public class InstanceSelectorTest {
     //     segment2 -> instance3
     //     segment3 -> instance3
     //     segment4 -> null
+    requestId++;
     expectedBalancedInstanceSelectorResult = new HashMap<>();
     expectedBalancedInstanceSelectorResult.put(segment1, instance2);
     expectedBalancedInstanceSelectorResult.put(segment2, instance1);
     expectedBalancedInstanceSelectorResult.put(segment3, instance3);
-    selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+    selectionResult = balancedInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedBalancedInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
     expectedReplicaGroupInstanceSelectorResult = new HashMap<>();
     expectedReplicaGroupInstanceSelectorResult.put(segment1, instance2);
     expectedReplicaGroupInstanceSelectorResult.put(segment2, instance3);
     expectedReplicaGroupInstanceSelectorResult.put(segment3, instance3);
-    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
-    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
 
@@ -389,12 +407,13 @@ public class InstanceSelectorTest {
     //     segment2 -> instance1
     //     segment3 -> instance1
     //     segment4 -> instance2
+    requestId++;
     expectedBalancedInstanceSelectorResult = new HashMap<>();
     expectedBalancedInstanceSelectorResult.put(segment1, instance2);
     expectedBalancedInstanceSelectorResult.put(segment2, instance3);
     expectedBalancedInstanceSelectorResult.put(segment3, instance1);
     expectedBalancedInstanceSelectorResult.put(segment4, instance2);
-    selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+    selectionResult = balancedInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedBalancedInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
     expectedReplicaGroupInstanceSelectorResult = new HashMap<>();
@@ -402,10 +421,10 @@ public class InstanceSelectorTest {
     expectedReplicaGroupInstanceSelectorResult.put(segment2, instance1);
     expectedReplicaGroupInstanceSelectorResult.put(segment3, instance1);
     expectedReplicaGroupInstanceSelectorResult.put(segment4, instance2);
-    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
-    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
 
@@ -420,12 +439,13 @@ public class InstanceSelectorTest {
     //     segment2 -> instance3
     //     segment3 -> instance3
     //     segment4 -> instance2
+    requestId++;
     expectedBalancedInstanceSelectorResult = new HashMap<>();
     expectedBalancedInstanceSelectorResult.put(segment1, instance2);
     expectedBalancedInstanceSelectorResult.put(segment2, instance1);
     expectedBalancedInstanceSelectorResult.put(segment3, instance3);
     expectedBalancedInstanceSelectorResult.put(segment4, instance2);
-    selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+    selectionResult = balancedInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedBalancedInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
     expectedReplicaGroupInstanceSelectorResult = new HashMap<>();
@@ -433,10 +453,10 @@ public class InstanceSelectorTest {
     expectedReplicaGroupInstanceSelectorResult.put(segment2, instance3);
     expectedReplicaGroupInstanceSelectorResult.put(segment3, instance3);
     expectedReplicaGroupInstanceSelectorResult.put(segment4, instance2);
-    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
-    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
 
@@ -457,12 +477,13 @@ public class InstanceSelectorTest {
     //     segment2 -> instance1
     //     segment3 -> instance1
     //     segment4 -> instance0
+    requestId++;
     expectedBalancedInstanceSelectorResult = new HashMap<>();
     expectedBalancedInstanceSelectorResult.put(segment1, instance0);
     expectedBalancedInstanceSelectorResult.put(segment2, instance3);
     expectedBalancedInstanceSelectorResult.put(segment3, instance1);
     expectedBalancedInstanceSelectorResult.put(segment4, instance2);
-    selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+    selectionResult = balancedInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedBalancedInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
     expectedReplicaGroupInstanceSelectorResult = new HashMap<>();
@@ -470,10 +491,10 @@ public class InstanceSelectorTest {
     expectedReplicaGroupInstanceSelectorResult.put(segment2, instance1);
     expectedReplicaGroupInstanceSelectorResult.put(segment3, instance1);
     expectedReplicaGroupInstanceSelectorResult.put(segment4, instance0);
-    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
-    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
 
@@ -488,12 +509,13 @@ public class InstanceSelectorTest {
     //     segment2 -> instance3
     //     segment3 -> instance3
     //     segment4 -> instance2
+    requestId++;
     expectedBalancedInstanceSelectorResult = new HashMap<>();
     expectedBalancedInstanceSelectorResult.put(segment1, instance2);
     expectedBalancedInstanceSelectorResult.put(segment2, instance1);
     expectedBalancedInstanceSelectorResult.put(segment3, instance3);
     expectedBalancedInstanceSelectorResult.put(segment4, instance0);
-    selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+    selectionResult = balancedInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedBalancedInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
     expectedReplicaGroupInstanceSelectorResult = new HashMap<>();
@@ -501,10 +523,10 @@ public class InstanceSelectorTest {
     expectedReplicaGroupInstanceSelectorResult.put(segment2, instance3);
     expectedReplicaGroupInstanceSelectorResult.put(segment3, instance3);
     expectedReplicaGroupInstanceSelectorResult.put(segment4, instance2);
-    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
-    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
 
@@ -534,12 +556,13 @@ public class InstanceSelectorTest {
     //     segment2 -> instance1
     //     segment3 -> instance1
     //     segment4 -> instance2
+    requestId++;
     expectedBalancedInstanceSelectorResult = new HashMap<>();
     expectedBalancedInstanceSelectorResult.put(segment1, instance2);
     expectedBalancedInstanceSelectorResult.put(segment2, instance3);
     expectedBalancedInstanceSelectorResult.put(segment3, instance1);
     expectedBalancedInstanceSelectorResult.put(segment4, instance2);
-    selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+    selectionResult = balancedInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedBalancedInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
     expectedReplicaGroupInstanceSelectorResult = new HashMap<>();
@@ -547,7 +570,7 @@ public class InstanceSelectorTest {
     expectedReplicaGroupInstanceSelectorResult.put(segment2, instance1);
     expectedReplicaGroupInstanceSelectorResult.put(segment3, instance1);
     expectedReplicaGroupInstanceSelectorResult.put(segment4, instance0);
-    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
     Map<String, String> expectedStrictReplicaGroupInstanceSelectorResult = new HashMap<>();
@@ -555,7 +578,7 @@ public class InstanceSelectorTest {
     expectedStrictReplicaGroupInstanceSelectorResult.put(segment2, instance1);
     expectedStrictReplicaGroupInstanceSelectorResult.put(segment3, instance1);
     expectedStrictReplicaGroupInstanceSelectorResult.put(segment4, instance2);
-    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedStrictReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
 
@@ -570,12 +593,13 @@ public class InstanceSelectorTest {
     //     segment2 -> instance3
     //     segment3 -> instance3
     //     segment4 -> instance2
+    requestId++;
     expectedBalancedInstanceSelectorResult = new HashMap<>();
     expectedBalancedInstanceSelectorResult.put(segment1, instance2);
     expectedBalancedInstanceSelectorResult.put(segment2, instance1);
     expectedBalancedInstanceSelectorResult.put(segment3, instance3);
     expectedBalancedInstanceSelectorResult.put(segment4, instance0);
-    selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+    selectionResult = balancedInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedBalancedInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
     expectedReplicaGroupInstanceSelectorResult = new HashMap<>();
@@ -583,10 +607,10 @@ public class InstanceSelectorTest {
     expectedReplicaGroupInstanceSelectorResult.put(segment2, instance3);
     expectedReplicaGroupInstanceSelectorResult.put(segment3, instance3);
     expectedReplicaGroupInstanceSelectorResult.put(segment4, instance2);
-    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
-    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, requestId);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
   }
@@ -668,12 +692,11 @@ public class InstanceSelectorTest {
     expectedReplicaGroupInstanceSelectorResult.put(segments.get(9), instance1);
     expectedReplicaGroupInstanceSelectorResult.put(segments.get(10), instance0);
     expectedReplicaGroupInstanceSelectorResult.put(segments.get(11), instance1);
-    InstanceSelector.SelectionResult selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments);
+    InstanceSelector.SelectionResult selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments, 0);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
   }
 
-
   @Test
   public void testReplicaGroupInstanceSelectorNumReplicaGroupsToQueryGreaterThanReplicas() {
     String offlineTableName = "testTable_OFFLINE";
@@ -752,7 +775,7 @@ public class InstanceSelectorTest {
     expectedReplicaGroupInstanceSelectorResult.put(segments.get(9), instance0);
     expectedReplicaGroupInstanceSelectorResult.put(segments.get(10), instance1);
     expectedReplicaGroupInstanceSelectorResult.put(segments.get(11), instance2);
-    InstanceSelector.SelectionResult selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments);
+    InstanceSelector.SelectionResult selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments, 0);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
     assertTrue(selectionResult.getUnavailableSegments().isEmpty());
   }
@@ -814,14 +837,126 @@ public class InstanceSelectorTest {
     for (String segment: segments) {
       expectedReplicaGroupInstanceSelectorResult.put(segment, instance0);
     }
-    InstanceSelector.SelectionResult selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments);
+    InstanceSelector.SelectionResult selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments, 0);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
 
     for (String segment: segments) {
       expectedReplicaGroupInstanceSelectorResult.put(segment, instance1);
     }
-    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = replicaGroupInstanceSelector.select(brokerRequest, segments, 1);
+    assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
+  }
+
+  @Test
+  public void testMultiStageStrictReplicaGroupSelector() {
+    String offlineTableName = "testTable_OFFLINE";
+    // Create instance-partitions with two replica-groups and 1 partition. Each replica-group has 2 instances.
+    List<String> replicaGroup0 = ImmutableList.of("instance-0", "instance-1");
+    List<String> replicaGroup1 = ImmutableList.of("instance-2", "instance-3");
+    Map<String, List<String>> partitionToInstances = ImmutableMap.of(
+        "0_0", replicaGroup0,
+        "0_1", replicaGroup1);
+    InstancePartitions instancePartitions = new InstancePartitions(offlineTableName);
+    instancePartitions.setInstances(0, 0, partitionToInstances.get("0_0"));
+    instancePartitions.setInstances(0, 1, partitionToInstances.get("0_1"));
+    BrokerMetrics brokerMetrics = mock(BrokerMetrics.class);
+    BrokerRequest brokerRequest = mock(BrokerRequest.class);
+    PinotQuery pinotQuery = mock(PinotQuery.class);
+    Map<String, String> queryOptions = new HashMap<>();
+
+    when(brokerRequest.getPinotQuery()).thenReturn(pinotQuery);
+    when(pinotQuery.getQueryOptions()).thenReturn(queryOptions);
+
+    ZkHelixPropertyStore<ZNRecord> propertyStore = (ZkHelixPropertyStore<ZNRecord>) mock(ZkHelixPropertyStore.class);
+
+    MultiStageReplicaGroupSelector multiStageSelector =
+        new MultiStageReplicaGroupSelector(offlineTableName, propertyStore, brokerMetrics, null);
+    multiStageSelector = spy(multiStageSelector);
+    doReturn(instancePartitions).when(multiStageSelector).getInstancePartitions();
+
+    List<String> enabledInstances = new ArrayList<>();
+    IdealState idealState = new IdealState(offlineTableName);
+    Map<String, Map<String, String>> idealStateSegmentAssignment = idealState.getRecord().getMapFields();
+    ExternalView externalView = new ExternalView(offlineTableName);
+    Map<String, Map<String, String>> externalViewSegmentAssignment = externalView.getRecord().getMapFields();
+    Set<String> onlineSegments = new HashSet<>();
+
+    // Mark all instances as enabled
+    for (int i = 0; i < 4; i++) {
+      enabledInstances.add(String.format("instance-%d", i));
+    }
+
+    List<String> segments = getSegments();
+
+    // Create two idealState and externalView maps: one for even-numbered segments, the other for odd-numbered ones
+    Map<String, String> idealStateInstanceStateMap0 = new TreeMap<>();
+    Map<String, String> externalViewInstanceStateMap0 = new TreeMap<>();
+    Map<String, String> idealStateInstanceStateMap1 = new TreeMap<>();
+    Map<String, String> externalViewInstanceStateMap1 = new TreeMap<>();
+
+    // instance-0 and instance-2 mirror each other in the two replica-groups. Same for instance-1 and instance-3.
+    for (int i = 0; i < 4; i++) {
+      String instance = enabledInstances.get(i);
+      if (i % 2 == 0) {
+        idealStateInstanceStateMap0.put(instance, ONLINE);
+        externalViewInstanceStateMap0.put(instance, ONLINE);
+      } else {
+        idealStateInstanceStateMap1.put(instance, ONLINE);
+        externalViewInstanceStateMap1.put(instance, ONLINE);
+      }
+    }
+
+    // Even numbered segments get assigned to [instance-0, instance-2], and odd numbered segments get assigned to
+    // [instance-1, instance-3].
+    for (int segmentNum = 0; segmentNum < segments.size(); segmentNum++) {
+      String segment = segments.get(segmentNum);
+      if (segmentNum % 2 == 0) {
+        idealStateSegmentAssignment.put(segment, idealStateInstanceStateMap0);
+        externalViewSegmentAssignment.put(segment, externalViewInstanceStateMap0);
+      } else {
+        idealStateSegmentAssignment.put(segment, idealStateInstanceStateMap1);
+        externalViewSegmentAssignment.put(segment, externalViewInstanceStateMap1);
+      }
+      onlineSegments.add(segment);
+    }
+
+    multiStageSelector.init(new HashSet<>(enabledInstances), idealState, externalView, onlineSegments);
+
+    // Using requestId=0 should select replica-group 0. Even segments get assigned to instance-0 and odd segments get
+    // assigned to instance-1.
+    Map<String, String> expectedReplicaGroupInstanceSelectorResult = new HashMap<>();
+    for (int segmentNum = 0; segmentNum < segments.size(); segmentNum++) {
+      expectedReplicaGroupInstanceSelectorResult.put(segments.get(segmentNum), replicaGroup0.get(segmentNum % 2));
+    }
+    InstanceSelector.SelectionResult selectionResult = multiStageSelector.select(brokerRequest, segments, 0);
+    assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
+
+    // Using same requestId again should return the same selection
+    selectionResult = multiStageSelector.select(brokerRequest, segments, 0);
+    assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
+
+    // Using requestId=1 should select replica-group 1
+    expectedReplicaGroupInstanceSelectorResult = new HashMap<>();
+    for (int segmentNum = 0; segmentNum < segments.size(); segmentNum++) {
+      expectedReplicaGroupInstanceSelectorResult.put(segments.get(segmentNum), replicaGroup1.get(segmentNum % 2));
+    }
+    selectionResult = multiStageSelector.select(brokerRequest, segments, 1);
+    assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
+
+    // If instance-0 is down, replica-group 1 should be picked even with requestId=0
+    enabledInstances.remove("instance-0");
+    multiStageSelector.init(new HashSet<>(enabledInstances), idealState, externalView, onlineSegments);
+    selectionResult = multiStageSelector.select(brokerRequest, segments, 0);
     assertEquals(selectionResult.getSegmentToInstanceMap(), expectedReplicaGroupInstanceSelectorResult);
+
+    // If instance-2 also goes down, no replica-group is eligible
+    enabledInstances.remove("instance-2");
+    multiStageSelector.init(new HashSet<>(enabledInstances), idealState, externalView, onlineSegments);
+    try {
+      multiStageSelector.select(brokerRequest, segments, 0);
+      fail("Method call above should have failed");
+    } catch (Exception ignored) {
+    }
   }
 
   @Test
@@ -880,10 +1015,10 @@ public class InstanceSelectorTest {
     PinotQuery pinotQuery = mock(PinotQuery.class);
     when(brokerRequest.getPinotQuery()).thenReturn(pinotQuery);
     when(pinotQuery.getQueryOptions()).thenReturn(null);
-    InstanceSelector.SelectionResult selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+    InstanceSelector.SelectionResult selectionResult = balancedInstanceSelector.select(brokerRequest, segments, 0);
     assertTrue(selectionResult.getSegmentToInstanceMap().isEmpty());
     assertEquals(selectionResult.getUnavailableSegments(), Arrays.asList(segment0, segment1));
-    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+    selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, 0);
     assertTrue(selectionResult.getSegmentToInstanceMap().isEmpty());
     assertEquals(selectionResult.getUnavailableSegments(), Arrays.asList(segment0, segment1));
 
@@ -904,10 +1039,10 @@ public class InstanceSelectorTest {
       enabledInstances.add(errorInstance);
       balancedInstanceSelector.onInstancesChange(enabledInstances, Collections.singletonList(errorInstance));
       strictReplicaGroupInstanceSelector.onInstancesChange(enabledInstances, Collections.singletonList(errorInstance));
-      selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+      selectionResult = balancedInstanceSelector.select(brokerRequest, segments, 0);
       assertTrue(selectionResult.getSegmentToInstanceMap().isEmpty());
       assertEquals(selectionResult.getUnavailableSegments(), Arrays.asList(segment0, segment1));
-      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, 0);
       assertTrue(selectionResult.getSegmentToInstanceMap().isEmpty());
       assertEquals(selectionResult.getUnavailableSegments(), Arrays.asList(segment0, segment1));
 
@@ -925,10 +1060,10 @@ public class InstanceSelectorTest {
       enabledInstances.add(instance);
       balancedInstanceSelector.onInstancesChange(enabledInstances, Collections.singletonList(instance));
       strictReplicaGroupInstanceSelector.onInstancesChange(enabledInstances, Collections.singletonList(instance));
-      selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+      selectionResult = balancedInstanceSelector.select(brokerRequest, segments, 0);
       assertEquals(selectionResult.getSegmentToInstanceMap().size(), 2);
       assertTrue(selectionResult.getUnavailableSegments().isEmpty());
-      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, 0);
       assertEquals(selectionResult.getSegmentToInstanceMap().size(), 2);
       assertTrue(selectionResult.getUnavailableSegments().isEmpty());
 
@@ -936,10 +1071,10 @@ public class InstanceSelectorTest {
       idealStateInstanceStateMap.put(instance, ONLINE);
       balancedInstanceSelector.onAssignmentChange(idealState, externalView, onlineSegments);
       strictReplicaGroupInstanceSelector.onAssignmentChange(idealState, externalView, onlineSegments);
-      selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+      selectionResult = balancedInstanceSelector.select(brokerRequest, segments, 0);
       assertEquals(selectionResult.getSegmentToInstanceMap().size(), 2);
       assertTrue(selectionResult.getUnavailableSegments().isEmpty());
-      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, 0);
       assertEquals(selectionResult.getSegmentToInstanceMap().size(), 2);
       assertTrue(selectionResult.getUnavailableSegments().isEmpty());
 
@@ -958,10 +1093,10 @@ public class InstanceSelectorTest {
       externalViewInstanceStateMap1.put(instance, ONLINE);
       balancedInstanceSelector.onAssignmentChange(idealState, externalView, onlineSegments);
       strictReplicaGroupInstanceSelector.onAssignmentChange(idealState, externalView, onlineSegments);
-      selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+      selectionResult = balancedInstanceSelector.select(brokerRequest, segments, 0);
       assertEquals(selectionResult.getSegmentToInstanceMap().size(), 2);
       assertTrue(selectionResult.getUnavailableSegments().isEmpty());
-      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, 0);
       assertEquals(selectionResult.getSegmentToInstanceMap().size(), 2);
       assertTrue(selectionResult.getUnavailableSegments().isEmpty());
 
@@ -969,10 +1104,10 @@ public class InstanceSelectorTest {
       idealStateInstanceStateMap.remove(instance);
       balancedInstanceSelector.onAssignmentChange(idealState, externalView, onlineSegments);
       strictReplicaGroupInstanceSelector.onAssignmentChange(idealState, externalView, onlineSegments);
-      selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+      selectionResult = balancedInstanceSelector.select(brokerRequest, segments, 0);
       assertTrue(selectionResult.getSegmentToInstanceMap().isEmpty());
       assertEquals(selectionResult.getUnavailableSegments(), Arrays.asList(segment0, segment1));
-      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, 0);
       assertTrue(selectionResult.getSegmentToInstanceMap().isEmpty());
       assertEquals(selectionResult.getUnavailableSegments(), Arrays.asList(segment0, segment1));
 
@@ -993,10 +1128,10 @@ public class InstanceSelectorTest {
       externalViewInstanceStateMap1.put(errorInstance, ONLINE);
       balancedInstanceSelector.onAssignmentChange(idealState, externalView, onlineSegments);
       strictReplicaGroupInstanceSelector.onAssignmentChange(idealState, externalView, onlineSegments);
-      selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+      selectionResult = balancedInstanceSelector.select(brokerRequest, segments, 0);
       assertEquals(selectionResult.getSegmentToInstanceMap().size(), 2);
       assertTrue(selectionResult.getUnavailableSegments().isEmpty());
-      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, 0);
       assertTrue(selectionResult.getSegmentToInstanceMap().isEmpty());
       assertEquals(selectionResult.getUnavailableSegments(), Arrays.asList(segment0, segment1));
 
@@ -1017,10 +1152,10 @@ public class InstanceSelectorTest {
       externalViewInstanceStateMap1.put(errorInstance, ERROR);
       balancedInstanceSelector.onAssignmentChange(idealState, externalView, onlineSegments);
       strictReplicaGroupInstanceSelector.onAssignmentChange(idealState, externalView, onlineSegments);
-      selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+      selectionResult = balancedInstanceSelector.select(brokerRequest, segments, 0);
       assertTrue(selectionResult.getSegmentToInstanceMap().isEmpty());
       assertTrue(selectionResult.getUnavailableSegments().isEmpty());
-      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, 0);
       assertTrue(selectionResult.getSegmentToInstanceMap().isEmpty());
       assertTrue(selectionResult.getUnavailableSegments().isEmpty());
 
@@ -1038,10 +1173,10 @@ public class InstanceSelectorTest {
       enabledInstances.remove(instance);
       balancedInstanceSelector.onInstancesChange(enabledInstances, Collections.singletonList(instance));
       strictReplicaGroupInstanceSelector.onInstancesChange(enabledInstances, Collections.singletonList(instance));
-      selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+      selectionResult = balancedInstanceSelector.select(brokerRequest, segments, 0);
       assertTrue(selectionResult.getSegmentToInstanceMap().isEmpty());
       assertEquals(selectionResult.getUnavailableSegments(), Arrays.asList(segment0, segment1));
-      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, 0);
       assertTrue(selectionResult.getSegmentToInstanceMap().isEmpty());
       assertEquals(selectionResult.getUnavailableSegments(), Arrays.asList(segment0, segment1));
 
@@ -1061,10 +1196,10 @@ public class InstanceSelectorTest {
       externalViewInstanceStateMap0.put(errorInstance, ONLINE);
       balancedInstanceSelector.onAssignmentChange(idealState, externalView, onlineSegments);
       strictReplicaGroupInstanceSelector.onAssignmentChange(idealState, externalView, onlineSegments);
-      selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+      selectionResult = balancedInstanceSelector.select(brokerRequest, segments, 0);
       assertEquals(selectionResult.getSegmentToInstanceMap().size(), 1);
       assertEquals(selectionResult.getUnavailableSegments(), Collections.singletonList(segment1));
-      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, 0);
       assertEquals(selectionResult.getSegmentToInstanceMap().size(), 1);
       assertEquals(selectionResult.getUnavailableSegments(), Collections.singletonList(segment1));
 
@@ -1082,10 +1217,10 @@ public class InstanceSelectorTest {
       enabledInstances.remove(errorInstance);
       balancedInstanceSelector.onInstancesChange(enabledInstances, Collections.singletonList(errorInstance));
       strictReplicaGroupInstanceSelector.onInstancesChange(enabledInstances, Collections.singletonList(errorInstance));
-      selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+      selectionResult = balancedInstanceSelector.select(brokerRequest, segments, 0);
       assertTrue(selectionResult.getSegmentToInstanceMap().isEmpty());
       assertEquals(selectionResult.getUnavailableSegments(), Arrays.asList(segment0, segment1));
-      selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+      selectionResult = balancedInstanceSelector.select(brokerRequest, segments, 0);
       assertTrue(selectionResult.getSegmentToInstanceMap().isEmpty());
       assertEquals(selectionResult.getUnavailableSegments(), Arrays.asList(segment0, segment1));
 
@@ -1106,10 +1241,10 @@ public class InstanceSelectorTest {
       externalViewInstanceStateMap1.put(instance, CONSUMING);
       balancedInstanceSelector.onAssignmentChange(idealState, externalView, onlineSegments);
       strictReplicaGroupInstanceSelector.onAssignmentChange(idealState, externalView, onlineSegments);
-      selectionResult = balancedInstanceSelector.select(brokerRequest, segments);
+      selectionResult = balancedInstanceSelector.select(brokerRequest, segments, 0);
       assertTrue(selectionResult.getSegmentToInstanceMap().isEmpty());
       assertEquals(selectionResult.getUnavailableSegments(), Arrays.asList(segment0, segment1));
-      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments);
+      selectionResult = strictReplicaGroupInstanceSelector.select(brokerRequest, segments, 0);
       assertTrue(selectionResult.getSegmentToInstanceMap().isEmpty());
       assertEquals(selectionResult.getUnavailableSegments(), Arrays.asList(segment0, segment1));
     }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/routing/RoutingManager.java b/pinot-core/src/main/java/org/apache/pinot/core/routing/RoutingManager.java
index db535dcaa6..857f0207da 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/routing/RoutingManager.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/routing/RoutingManager.java
@@ -50,7 +50,7 @@ public interface RoutingManager {
    * @param brokerRequest the broker request constructed from a query.
    * @return the route table.
    */
-  RoutingTable getRoutingTable(BrokerRequest brokerRequest);
+  RoutingTable getRoutingTable(BrokerRequest brokerRequest, long requestId);
 
   /**
    * Validate routing exist for a table
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
index 814837ec15..d84a70d3f8 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
@@ -125,11 +125,11 @@ public class QueryEnvironment {
    * @param sqlNodeAndOptions parsed SQL query.
    * @return a dispatchable query plan
    */
-  public QueryPlan planQuery(String sqlQuery, SqlNodeAndOptions sqlNodeAndOptions) {
+  public QueryPlan planQuery(String sqlQuery, SqlNodeAndOptions sqlNodeAndOptions, long requestId) {
     try (PlannerContext plannerContext = new PlannerContext(_config, _catalogReader, _typeFactory, _hepProgram)) {
       plannerContext.setOptions(sqlNodeAndOptions.getOptions());
       RelRoot relRoot = compileQuery(sqlNodeAndOptions.getSqlNode(), plannerContext);
-      return toDispatchablePlan(relRoot, plannerContext);
+      return toDispatchablePlan(relRoot, plannerContext, requestId);
     } catch (CalciteContextException e) {
       throw new RuntimeException("Error composing query plan for '" + sqlQuery
           + "': " + e.getMessage() + "'", e);
@@ -141,9 +141,9 @@ public class QueryEnvironment {
   /**
    * Explain a SQL query.
    *
-   * Similar to {@link QueryEnvironment#planQuery(String, SqlNodeAndOptions)}, this API runs the query compilation.
-   * But it doesn't run the distributed {@link QueryPlan} generation, instead it only returns the explained logical
-   * plan.
+   * Similar to {@link QueryEnvironment#planQuery(String, SqlNodeAndOptions, long)}, this API runs the query
+   * compilation. But it doesn't run the distributed {@link QueryPlan} generation, instead it only returns the
+   * explained logical plan.
    *
    * @param sqlQuery SQL query string.
    * @param sqlNodeAndOptions parsed SQL query.
@@ -165,7 +165,7 @@ public class QueryEnvironment {
 
   @VisibleForTesting
   public QueryPlan planQuery(String sqlQuery) {
-    return planQuery(sqlQuery, CalciteSqlParser.compileToSqlNodeAndOptions(sqlQuery));
+    return planQuery(sqlQuery, CalciteSqlParser.compileToSqlNodeAndOptions(sqlQuery), 0);
   }
 
   @VisibleForTesting
@@ -219,9 +219,9 @@ public class QueryEnvironment {
     }
   }
 
-  private QueryPlan toDispatchablePlan(RelRoot relRoot, PlannerContext plannerContext) {
+  private QueryPlan toDispatchablePlan(RelRoot relRoot, PlannerContext plannerContext, long requestId) {
     // 5. construct a dispatchable query plan.
-    StagePlanner queryStagePlanner = new StagePlanner(plannerContext, _workerManager);
+    StagePlanner queryStagePlanner = new StagePlanner(plannerContext, _workerManager, requestId);
     return queryStagePlanner.makePlan(relRoot);
   }
 
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
index 2d61856c85..5f46b23d26 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
@@ -44,10 +44,12 @@ public class StagePlanner {
   private final PlannerContext _plannerContext;   // DO NOT REMOVE.
   private final WorkerManager _workerManager;
   private int _stageIdCounter;
+  private long _requestId;
 
-  public StagePlanner(PlannerContext plannerContext, WorkerManager workerManager) {
+  public StagePlanner(PlannerContext plannerContext, WorkerManager workerManager, long requestId) {
     _plannerContext = plannerContext;
     _workerManager = workerManager;
+    _requestId = requestId;
   }
 
   /**
@@ -79,7 +81,7 @@ public class StagePlanner {
 
     // assign workers to each stage.
     for (Map.Entry<Integer, StageMetadata> e : queryPlan.getStageMetadataMap().entrySet()) {
-      _workerManager.assignWorkerToStage(e.getKey(), e.getValue());
+      _workerManager.assignWorkerToStage(e.getKey(), e.getValue(), _requestId);
     }
 
     return queryPlan;
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
index 112aec606b..42bb19d269 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
@@ -58,12 +58,12 @@ public class WorkerManager {
     _routingManager = routingManager;
   }
 
-  public void assignWorkerToStage(int stageId, StageMetadata stageMetadata) {
+  public void assignWorkerToStage(int stageId, StageMetadata stageMetadata, long requestId) {
     List<String> scannedTables = stageMetadata.getScannedTables();
     if (scannedTables.size() == 1) {
       // table scan stage, need to attach server as well as segment info for each physical table type.
       String logicalTableName = scannedTables.get(0);
-      Map<String, RoutingTable> routingTableMap = getRoutingTable(logicalTableName);
+      Map<String, RoutingTable> routingTableMap = getRoutingTable(logicalTableName, requestId);
       if (routingTableMap.size() == 0) {
         throw new IllegalArgumentException("Unable to find routing entries for table: " + logicalTableName);
       }
@@ -124,22 +124,22 @@ public class WorkerManager {
    * @param logicalTableName it can either be a hybrid table name or a physical table name with table type.
    * @return keyed-map from table type(s) to routing table(s).
    */
-  private Map<String, RoutingTable> getRoutingTable(String logicalTableName) {
+  private Map<String, RoutingTable> getRoutingTable(String logicalTableName, long requestId) {
     String rawTableName = TableNameBuilder.extractRawTableName(logicalTableName);
     TableType tableType = TableNameBuilder.getTableTypeFromTableName(logicalTableName);
     Map<String, RoutingTable> routingTableMap = new HashMap<>();
     RoutingTable routingTable;
     if (tableType == null) {
-      routingTable = getRoutingTable(rawTableName, TableType.OFFLINE);
+      routingTable = getRoutingTable(rawTableName, TableType.OFFLINE, requestId);
       if (routingTable != null) {
         routingTableMap.put(TableType.OFFLINE.name(), routingTable);
       }
-      routingTable = getRoutingTable(rawTableName, TableType.REALTIME);
+      routingTable = getRoutingTable(rawTableName, TableType.REALTIME, requestId);
       if (routingTable != null) {
         routingTableMap.put(TableType.REALTIME.name(), routingTable);
       }
     } else {
-      routingTable = getRoutingTable(logicalTableName, tableType);
+      routingTable = getRoutingTable(logicalTableName, tableType, requestId);
       if (routingTable != null) {
         routingTableMap.put(tableType.name(), routingTable);
       }
@@ -147,10 +147,10 @@ public class WorkerManager {
     return routingTableMap;
   }
 
-  private RoutingTable getRoutingTable(String tableName, TableType tableType) {
+  private RoutingTable getRoutingTable(String tableName, TableType tableType, long requestId) {
     String tableNameWithType = TableNameBuilder.forType(tableType).tableNameWithType(
         TableNameBuilder.extractRawTableName(tableName));
     return _routingManager.getRoutingTable(
-        CalciteSqlCompiler.compileToBrokerRequest("SELECT * FROM " + tableNameWithType));
+        CalciteSqlCompiler.compileToBrokerRequest("SELECT * FROM " + tableNameWithType), requestId);
   }
 }
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/testutils/MockRoutingManagerFactory.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/testutils/MockRoutingManagerFactory.java
index 9d2d5e1a22..5e75de0455 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/testutils/MockRoutingManagerFactory.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/testutils/MockRoutingManagerFactory.java
@@ -144,7 +144,7 @@ public class MockRoutingManagerFactory {
     }
 
     @Override
-    public RoutingTable getRoutingTable(BrokerRequest brokerRequest) {
+    public RoutingTable getRoutingTable(BrokerRequest brokerRequest, long requestId) {
       String tableName = brokerRequest.getPinotQuery().getDataSource().getTableName();
       return _routingTableMap.getOrDefault(tableName,
           _routingTableMap.get(TableNameBuilder.extractRawTableName(tableName)));
diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/config/table/RoutingConfig.java b/pinot-spi/src/main/java/org/apache/pinot/spi/config/table/RoutingConfig.java
index 2c238aa287..8af5773675 100644
--- a/pinot-spi/src/main/java/org/apache/pinot/spi/config/table/RoutingConfig.java
+++ b/pinot-spi/src/main/java/org/apache/pinot/spi/config/table/RoutingConfig.java
@@ -31,6 +31,7 @@ public class RoutingConfig extends BaseJsonConfig {
   public static final String EMPTY_SEGMENT_PRUNER_TYPE = "empty";
   public static final String REPLICA_GROUP_INSTANCE_SELECTOR_TYPE = "replicaGroup";
   public static final String STRICT_REPLICA_GROUP_INSTANCE_SELECTOR_TYPE = "strictReplicaGroup";
+  public static final String MULTI_STAGE_REPLICA_GROUP_SELECTOR_TYPE = "multiStageReplicaGroup";
 
   // Replaced by _segmentPrunerTypes and _instanceSelectorType
   @Deprecated


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org