You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by sh...@apache.org on 2020/02/10 15:28:02 UTC

[lucene-solr] branch branch_8x updated: SOLR-13996: Refactor HttpShardHandler.prepDistributed method (#1220)

This is an automated email from the ASF dual-hosted git repository.

shalin pushed a commit to branch branch_8x
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git


The following commit(s) were added to refs/heads/branch_8x by this push:
     new 78e567c  SOLR-13996: Refactor HttpShardHandler.prepDistributed method (#1220)
78e567c is described below

commit 78e567c57e45f56ff22ecfc3e43e315255ef3561
Author: Shalin Shekhar Mangar <sh...@apache.org>
AuthorDate: Mon Feb 10 19:57:05 2020 +0530

    SOLR-13996: Refactor HttpShardHandler.prepDistributed method (#1220)
    
    SOLR-13996: Refactor HttpShardHandler.prepDistributed method into smaller pieces
    
    This commit introduces an interface named ReplicaSource which is marked as experimental. It has two implementations: CloudReplicaSource (for SolrCloud clusters) and LegacyReplicaSource (for non-cloud clusters). The prepDistributed method now delegates to one of these implementations depending on whether the cluster is running in cloud mode or not.
    
    (cherry picked from commit c65b97665c61116632bc93e5f88f84bdb5cccf21)
---
 solr/CHANGES.txt                                   |   2 +
 .../solr/handler/component/CloudReplicaSource.java | 246 +++++++++++++++
 .../solr/handler/component/HttpShardHandler.java   | 339 +++++----------------
 .../handler/component/HttpShardHandlerFactory.java |   4 +-
 .../handler/component/LegacyReplicaSource.java     |  77 +++++
 .../solr/handler/component/ReplicaSource.java}     |  37 +--
 .../apache/solr/cloud/ClusterStateMockUtil.java    |  19 +-
 .../handler/component/CloudReplicaSourceTest.java  | 263 ++++++++++++++++
 .../org/apache/solr/cloud/MockZkStateReader.java   |   7 +
 9 files changed, 714 insertions(+), 280 deletions(-)

diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt
index 92345b4..21b90e8 100644
--- a/solr/CHANGES.txt
+++ b/solr/CHANGES.txt
@@ -151,6 +151,8 @@ Other Changes
 
 * SOLR-14209: Upgrade JQuery to 3.4.1 (Kevin Risden)
 
+* SOLR-13996: Refactor HttpShardHandler.prepDistributed method. (shalin)
+
 ==================  8.4.1 ==================
 
 Consult the LUCENE_CHANGES.txt file for additional, low level, changes in this release.
diff --git a/solr/core/src/java/org/apache/solr/handler/component/CloudReplicaSource.java b/solr/core/src/java/org/apache/solr/handler/component/CloudReplicaSource.java
new file mode 100644
index 0000000..5ff8ec9
--- /dev/null
+++ b/solr/core/src/java/org/apache/solr/handler/component/CloudReplicaSource.java
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.handler.component;
+
+import java.lang.invoke.MethodHandles;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+
+import org.apache.solr.client.solrj.routing.ReplicaListTransformer;
+import org.apache.solr.client.solrj.util.ClientUtils;
+import org.apache.solr.common.SolrException;
+import org.apache.solr.common.cloud.ClusterState;
+import org.apache.solr.common.cloud.DocCollection;
+import org.apache.solr.common.cloud.Replica;
+import org.apache.solr.common.cloud.Slice;
+import org.apache.solr.common.cloud.ZkStateReader;
+import org.apache.solr.common.params.ShardParams;
+import org.apache.solr.common.params.SolrParams;
+import org.apache.solr.common.util.StrUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A replica source for solr cloud mode
+ */
+class CloudReplicaSource implements ReplicaSource {
+  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
+
+  private String[] slices;
+  private List<String>[] replicas;
+
+  private CloudReplicaSource(Builder builder) {
+    final String shards = builder.params.get(ShardParams.SHARDS);
+    if (shards != null) {
+      withShardsParam(builder, shards);
+    } else {
+      withClusterState(builder, builder.params);
+    }
+  }
+
+  private void withClusterState(Builder builder, SolrParams params) {
+    ClusterState clusterState = builder.zkStateReader.getClusterState();
+    String shardKeys = params.get(ShardParams._ROUTE_);
+
+    // Collects every slice this request must query; its key set becomes this.slices below.
+    Map<String, Slice> sliceMap = new HashMap<>();
+
+    // First work out which collection(s) this request targets.
+
+    // The "collection" parameter, when present, is a comma-separated list of collection names.
+    // Eg: "collection1,collection2,collection3"
+    String collections = params.get("collection");
+    if (collections != null) {
+      // One or more collections were named explicitly: split the parameter value
+      // on commas into the individual collection names.
+      List<String> collectionList = StrUtils.splitSmart(collections, ",",
+          true);
+      // In turn, retrieve the slices that cover each collection from the
+      // cloud state and add them to sliceMap.
+      for (String collectionName : collectionList) {
+        // The original code produced <collection-name>_<shard-name> when the "collection"
+        // parameter was specified (see ClientUtils.appendMap)
+        // Is this necessary if only one collection is specified?
+        // i.e. should we change multiCollection to collectionList.size() > 1?
+        addSlices(sliceMap, clusterState, params, collectionName, shardKeys, true);
+      }
+    } else {
+      // no explicit list of collections: use only the builder's collection
+      addSlices(sliceMap, clusterState, params, builder.collection, shardKeys, false);
+    }
+
+    this.slices = sliceMap.keySet().toArray(new String[sliceMap.size()]);
+    this.replicas = new List[slices.length];
+    for (int i = 0; i < slices.length; i++) {
+      String sliceName = slices[i];
+      replicas[i] = findReplicas(builder, null, clusterState, sliceMap.get(sliceName)); // null: the constructor takes this path only when no "shards" param was given
+    }
+  }
+
+  private void withShardsParam(Builder builder, String shardsParam) {
+    List<String> sliceOrUrls = StrUtils.splitSmart(shardsParam, ",", true); // comma-separated entries of the shards parameter
+    this.slices = new String[sliceOrUrls.size()];
+    this.replicas = new List[sliceOrUrls.size()];
+
+    ClusterState clusterState = builder.zkStateReader.getClusterState();
+
+    for (int i = 0; i < sliceOrUrls.size(); i++) {
+      String sliceOrUrl = sliceOrUrls.get(i);
+      if (sliceOrUrl.indexOf('/') < 0) {
+        // no '/' in the entry: treat it as a logical shard (slice) name of builder.collection
+        this.slices[i] = sliceOrUrl;
+        replicas[i] = findReplicas(builder, shardsParam, clusterState, clusterState.getCollection(builder.collection).getSlice(sliceOrUrl));
+      } else {
+        // the entry contains '/': treat it as a '|'-separated list of replica urls; slices[i] is left null for it
+        this.replicas[i] = StrUtils.splitSmart(sliceOrUrl, "|", true);
+        builder.replicaListTransformer.transform(replicas[i]);
+        builder.hostChecker.checkWhitelist(clusterState, shardsParam, replicas[i]);
+      }
+    }
+  }
+
+  private List<String> findReplicas(Builder builder, String shardsParam, ClusterState clusterState, Slice slice) {
+    if (slice == null) {
+      // Treat this the same as "all servers down" for a slice, and let things continue
+      // if partial results are acceptable
+      return Collections.emptyList();
+    } else {
+      final Predicate<Replica> isShardLeader = new IsLeaderPredicate(builder.zkStateReader, clusterState, slice.getCollection(), slice.getName());
+      List<Replica> list = slice.getReplicas()
+          .stream()
+          .filter(replica -> replica.isActive(clusterState.getLiveNodes()))
+          .filter(replica -> !builder.onlyNrt || (replica.getType() == Replica.Type.NRT || (replica.getType() == Replica.Type.TLOG && isShardLeader.test(replica))))
+          .collect(Collectors.toList());
+      builder.replicaListTransformer.transform(list);
+      List<String> collect = list.stream().map(Replica::getCoreUrl).collect(Collectors.toList());
+      builder.hostChecker.checkWhitelist(clusterState, shardsParam, collect);
+      return collect;
+    }
+  }
+
+  private void addSlices(Map<String, Slice> target, ClusterState state, SolrParams params, String collectionName, String shardKeys, boolean multiCollection) {
+    DocCollection coll = state.getCollection(collectionName);
+    Collection<Slice> slices = coll.getRouter().getSearchSlices(shardKeys, params, coll);
+    ClientUtils.addSlices(target, collectionName, slices, multiCollection);
+  }
+
+  @Override
+  public List<String> getSliceNames() {
+    return Collections.unmodifiableList(Arrays.asList(slices));
+  }
+
+  @Override
+  public List<String> getReplicasBySlice(int sliceNumber) {
+    assert sliceNumber >= 0 && sliceNumber < replicas.length;
+    return replicas[sliceNumber];
+  }
+
+  @Override
+  public int getSliceCount() {
+    return slices.length;
+  }
+
+  /**
+   * A predicate to test if a replica is the leader according to {@link ZkStateReader#getLeaderRetry(String, String)}.
+   * The result of getLeaderRetry is cached in the first call so that subsequent tests are faster and do not block.
+   * If the lookup is interrupted, the thread's interrupt status is restored before SERVICE_UNAVAILABLE is thrown.
+   */
+  private static class IsLeaderPredicate implements Predicate<Replica> {
+    private final ZkStateReader zkStateReader;
+    private final ClusterState clusterState;
+    private final String collectionName;
+    private final String sliceName;
+    private Replica shardLeader = null;
+
+    public IsLeaderPredicate(ZkStateReader zkStateReader, ClusterState clusterState, String collectionName, String sliceName) {
+      this.zkStateReader = zkStateReader;
+      this.clusterState = clusterState;
+      this.collectionName = collectionName;
+      this.sliceName = sliceName;
+    }
+
+    @Override
+    public boolean test(Replica replica) {
+      if (shardLeader == null) {
+        try {
+          shardLeader = zkStateReader.getLeaderRetry(collectionName, sliceName);
+        } catch (InterruptedException e) {
+          Thread.currentThread().interrupt(); // restore the interrupt status so code further up the stack can see it
+          throw new SolrException(SolrException.ErrorCode.SERVICE_UNAVAILABLE,
+              "Exception finding leader for shard " + sliceName + " in collection " + collectionName, e);
+        } catch (SolrException e) {
+          if (log.isDebugEnabled()) {
+            log.debug("Exception finding leader for shard {} in collection {}. Collection State: {}",
+                sliceName, collectionName, clusterState.getCollectionOrNull(collectionName));
+          }
+          throw e;
+        }
+      }
+      return replica.getName().equals(shardLeader.getName());
+    }
+  }
+
+  static class Builder {
+    private String collection;
+    private ZkStateReader zkStateReader;
+    private SolrParams params;
+    private boolean onlyNrt;
+    private ReplicaListTransformer replicaListTransformer;
+    private HttpShardHandlerFactory.WhitelistHostChecker hostChecker;
+
+    public Builder collection(String collection) {
+      this.collection = collection;
+      return this;
+    }
+
+    public Builder zkStateReader(ZkStateReader stateReader) {
+      this.zkStateReader = stateReader;
+      return this;
+    }
+
+    public Builder params(SolrParams params) {
+      this.params = params;
+      return this;
+    }
+
+    public Builder onlyNrt(boolean onlyNrt) {
+      this.onlyNrt = onlyNrt;
+      return this;
+    }
+
+    public Builder replicaListTransformer(ReplicaListTransformer replicaListTransformer) {
+      this.replicaListTransformer = replicaListTransformer;
+      return this;
+    }
+
+    public Builder whitelistHostChecker(HttpShardHandlerFactory.WhitelistHostChecker hostChecker) {
+      this.hostChecker = hostChecker;
+      return this;
+    }
+
+    public CloudReplicaSource build() {
+      return new CloudReplicaSource(this);
+    }
+  }
+}
diff --git a/solr/core/src/java/org/apache/solr/handler/component/HttpShardHandler.java b/solr/core/src/java/org/apache/solr/handler/component/HttpShardHandler.java
index 8f8c48e..ca6dddf 100644
--- a/solr/core/src/java/org/apache/solr/handler/component/HttpShardHandler.java
+++ b/solr/core/src/java/org/apache/solr/handler/component/HttpShardHandler.java
@@ -17,11 +17,7 @@
 package org.apache.solr.handler.component;
 
 import java.io.IOException;
-import java.lang.invoke.MethodHandles;
 import java.net.ConnectException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -32,7 +28,6 @@ import java.util.concurrent.CompletionService;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
-import java.util.function.Predicate;
 
 import io.opentracing.Span;
 import io.opentracing.Tracer;
@@ -40,43 +35,32 @@ import io.opentracing.propagation.Format;
 import org.apache.solr.client.solrj.SolrRequest;
 import org.apache.solr.client.solrj.SolrResponse;
 import org.apache.solr.client.solrj.SolrServerException;
-import org.apache.solr.client.solrj.impl.BinaryResponseParser;
 import org.apache.solr.client.solrj.impl.Http2SolrClient;
 import org.apache.solr.client.solrj.impl.LBSolrClient;
 import org.apache.solr.client.solrj.request.QueryRequest;
 import org.apache.solr.client.solrj.routing.ReplicaListTransformer;
-import org.apache.solr.client.solrj.util.ClientUtils;
 import org.apache.solr.cloud.CloudDescriptor;
 import org.apache.solr.cloud.ZkController;
 import org.apache.solr.common.SolrException;
-import org.apache.solr.common.SolrException.ErrorCode;
-import org.apache.solr.common.cloud.ClusterState;
-import org.apache.solr.common.cloud.DocCollection;
 import org.apache.solr.common.cloud.Replica;
-import org.apache.solr.common.cloud.Slice;
 import org.apache.solr.common.cloud.ZkCoreNodeProps;
 import org.apache.solr.common.params.CommonParams;
 import org.apache.solr.common.params.ModifiableSolrParams;
 import org.apache.solr.common.params.ShardParams;
 import org.apache.solr.common.params.SolrParams;
-import org.apache.solr.common.util.JavaBinCodec;
 import org.apache.solr.common.util.NamedList;
-import org.apache.solr.common.util.StrUtils;
 import org.apache.solr.core.CoreDescriptor;
 import org.apache.solr.request.SolrQueryRequest;
 import org.apache.solr.request.SolrRequestInfo;
 import org.apache.solr.util.tracing.GlobalTracer;
 import org.apache.solr.util.tracing.SolrRequestCarrier;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 import org.slf4j.MDC;
 
 public class HttpShardHandler extends ShardHandler {
-  
   /**
    * If the request context map has an entry with this key and Boolean.TRUE as value,
    * {@link #prepDistributed(ResponseBuilder)} will only include {@link org.apache.solr.common.cloud.Replica.Type#NRT} replicas as possible
-   * destination of the distributed request (or a leader replica of type {@link org.apache.solr.common.cloud.Replica.Type#TLOG}). This is used 
+   * destination of the distributed request (or a leader replica of type {@link org.apache.solr.common.cloud.Replica.Type#TLOG}). This is used
    * by the RealtimeGet handler, since other types of replicas shouldn't respond to RTG requests
    */
   public static String ONLY_NRT_REPLICAS = "distribOnlyRealtime";
@@ -84,11 +68,9 @@ public class HttpShardHandler extends ShardHandler {
   private HttpShardHandlerFactory httpShardHandlerFactory;
   private CompletionService<ShardResponse> completionService;
   private Set<Future<ShardResponse>> pending;
-  private Map<String,List<String>> shardToURLs;
+  private Map<String, List<String>> shardToURLs;
   private Http2SolrClient httpClient;
 
-  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
-
   public HttpShardHandler(HttpShardHandlerFactory httpShardHandlerFactory, Http2SolrClient httpClient) {
     this.httpClient = httpClient;
     this.httpShardHandlerFactory = httpShardHandlerFactory;
@@ -141,19 +123,12 @@ public class HttpShardHandler extends ShardHandler {
     return urls;
   }
 
-  private static final BinaryResponseParser READ_STR_AS_CHARSEQ_PARSER = new BinaryResponseParser() {
-    @Override
-    protected JavaBinCodec createCodec() {
-      return new JavaBinCodec(null, stringCache).setReadStringAsCharSeq(true);
-    }
-  };
-
   @Override
   public void submit(final ShardRequest sreq, final String shard, final ModifiableSolrParams params) {
     // do this outside of the callable for thread safety reasons
     final List<String> urls = getURLs(shard);
     final Tracer tracer = GlobalTracer.getTracer();
-    final Span span = tracer != null? tracer.activeSpan() : null;
+    final Span span = tracer != null ? tracer.activeSpan() : null;
 
     Callable<ShardResponse> task = () -> {
 
@@ -183,7 +158,7 @@ public class HttpShardHandler extends ShardHandler {
         // req.setResponseParser(new BinaryResponseParser());
 
         // if there are no shards available for a slice, urls.size()==0
-        if (urls.size()==0) {
+        if (urls.size() == 0) {
           // TODO: what's the right error code here? We should use the same thing when
           // all of the servers for a shard are down.
           throw new SolrException(SolrException.ErrorCode.SERVICE_UNAVAILABLE, "no servers hosting shard: " + shard);
@@ -198,13 +173,12 @@ public class HttpShardHandler extends ShardHandler {
           ssr.nl = rsp.getResponse();
           srsp.setShardAddress(rsp.getServer());
         }
-      }
-      catch( ConnectException cex ) {
+      } catch (ConnectException cex) {
         srsp.setException(cex); //????
       } catch (Exception th) {
         srsp.setException(th);
         if (th instanceof SolrException) {
-          srsp.setResponseCode(((SolrException)th).code());
+          srsp.setResponseCode(((SolrException) th).code());
         } else {
           srsp.setResponseCode(-1);
         }
@@ -216,13 +190,13 @@ public class HttpShardHandler extends ShardHandler {
     };
 
     try {
-      if (shard != null)  {
+      if (shard != null) {
         MDC.put("ShardRequest.shards", shard);
       }
-      if (urls != null && !urls.isEmpty())  {
+      if (urls != null && !urls.isEmpty()) {
         MDC.put("ShardRequest.urlList", urls.toString());
       }
-      pending.add( completionService.submit(task) );
+      pending.add(completionService.submit(task));
     } finally {
       MDC.remove("ShardRequest.shards");
       MDC.remove("ShardRequest.urlList");
@@ -233,26 +207,25 @@ public class HttpShardHandler extends ShardHandler {
     req.setBasePath(url);
     return httpClient.request(req);
   }
-  
+
   /**
    * Subclasses could modify the request based on the shard
    */
-  protected QueryRequest makeQueryRequest(final ShardRequest sreq, ModifiableSolrParams params, String shard)
-  {
+  protected QueryRequest makeQueryRequest(final ShardRequest sreq, ModifiableSolrParams params, String shard) {
     // use generic request to avoid extra processing of queries
     return new QueryRequest(params);
   }
-  
+
   /**
    * Subclasses could modify the Response based on the the shard
    */
-  protected ShardResponse transfomResponse(final ShardRequest sreq, ShardResponse rsp, String shard)
-  {
+  protected ShardResponse transfomResponse(final ShardRequest sreq, ShardResponse rsp, String shard) {
     return rsp;
   }
 
-  /** returns a ShardResponse of the last response correlated with a ShardRequest.  This won't 
-   * return early if it runs into an error.  
+  /**
+   * returns a ShardResponse of the last response correlated with a ShardRequest.  This won't
+   * return early if it runs into an error.
    **/
   @Override
   public ShardResponse takeCompletedIncludingErrors() {
@@ -260,16 +233,17 @@ public class HttpShardHandler extends ShardHandler {
   }
 
 
-  /** returns a ShardResponse of the last response correlated with a ShardRequest,
+  /**
+   * returns a ShardResponse of the last response correlated with a ShardRequest,
    * or immediately returns a ShardResponse if there was an error detected
    */
   @Override
   public ShardResponse takeCompletedOrError() {
     return take(true);
   }
-  
+
   private ShardResponse take(boolean bailOnError) {
-    
+
     while (pending.size() > 0) {
       try {
         Future<ShardResponse> future = completionService.take();
@@ -289,7 +263,7 @@ public class HttpShardHandler extends ShardHandler {
       } catch (ExecutionException e) {
         // should be impossible... the problem with catching the exception
         // at this level is we don't know what ShardRequest it applied to
-        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Impossible Exception",e);
+        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Impossible Exception", e);
       }
     }
     return null;
@@ -309,237 +283,71 @@ public class HttpShardHandler extends ShardHandler {
     final SolrParams params = req.getParams();
     final String shards = params.get(ShardParams.SHARDS);
 
-    // since the cost of grabbing cloud state is still up in the air, we grab it only
-    // if we need it.
-    ClusterState clusterState = null;
-    Map<String,Slice> slices = null;
     CoreDescriptor coreDescriptor = req.getCore().getCoreDescriptor();
     CloudDescriptor cloudDescriptor = coreDescriptor.getCloudDescriptor();
     ZkController zkController = req.getCore().getCoreContainer().getZkController();
 
     final ReplicaListTransformer replicaListTransformer = httpShardHandlerFactory.getReplicaListTransformer(req);
 
-    if (shards != null) {
-      List<String> lst = StrUtils.splitSmart(shards, ",", true);
-      rb.shards = lst.toArray(new String[lst.size()]);
-      rb.slices = new String[rb.shards.length];
-
-      if (zkController != null) {
-        // figure out which shards are slices
-        for (int i=0; i<rb.shards.length; i++) {
-          if (rb.shards[i].indexOf('/') < 0) {
-            // this is a logical shard
-            rb.slices[i] = rb.shards[i];
-            rb.shards[i] = null;
-          }
-        }
-      }
-    } else if (zkController != null) {
-      // we weren't provided with an explicit list of slices to query via "shards", so use the cluster state
-
-      clusterState =  zkController.getClusterState();
-      String shardKeys =  params.get(ShardParams._ROUTE_);
-
-      // This will be the complete list of slices we need to query for this request.
-      slices = new HashMap<>();
-
-      // we need to find out what collections this request is for.
-
-      // A comma-separated list of specified collections.
-      // Eg: "collection1,collection2,collection3"
-      String collections = params.get("collection");
-      if (collections != null) {
-        // If there were one or more collections specified in the query, split
-        // each parameter and store as a separate member of a List.
-        List<String> collectionList = StrUtils.splitSmart(collections, ",",
-            true);
-        // In turn, retrieve the slices that cover each collection from the
-        // cloud state and add them to the Map 'slices'.
-        for (String collectionName : collectionList) {
-          // The original code produced <collection-name>_<shard-name> when the collections
-          // parameter was specified (see ClientUtils.appendMap)
-          // Is this necessary if ony one collection is specified?
-          // i.e. should we change multiCollection to collectionList.size() > 1?
-          addSlices(slices, clusterState, params, collectionName,  shardKeys, true);
-        }
-      } else {
-        // just this collection
-        String collectionName = cloudDescriptor.getCollectionName();
-        addSlices(slices, clusterState, params, collectionName,  shardKeys, false);
-      }
-
-
-      // Store the logical slices in the ResponseBuilder and create a new
-      // String array to hold the physical shards (which will be mapped
-      // later).
-      rb.slices = slices.keySet().toArray(new String[slices.size()]);
-      rb.shards = new String[rb.slices.length];
-    }
-
     HttpShardHandlerFactory.WhitelistHostChecker hostChecker = httpShardHandlerFactory.getWhitelistHostChecker();
     if (shards != null && zkController == null && hostChecker.isWhitelistHostCheckingEnabled() && !hostChecker.hasExplicitWhitelist()) {
-      throw new SolrException(ErrorCode.FORBIDDEN, "HttpShardHandlerFactory "+HttpShardHandlerFactory.INIT_SHARDS_WHITELIST
-          +" not configured but required (in lieu of ZkController and ClusterState) when using the '"+ShardParams.SHARDS+"' parameter."
-          +HttpShardHandlerFactory.SET_SOLR_DISABLE_SHARDS_WHITELIST_CLUE);
+      throw new SolrException(SolrException.ErrorCode.FORBIDDEN, "HttpShardHandlerFactory " + HttpShardHandlerFactory.INIT_SHARDS_WHITELIST
+          + " not configured but required (in lieu of ZkController and ClusterState) when using the '" + ShardParams.SHARDS + "' parameter."
+          + HttpShardHandlerFactory.SET_SOLR_DISABLE_SHARDS_WHITELIST_CLUE);
     }
 
-    //
-    // Map slices to shards
-    //
+    ReplicaSource replicaSource;
     if (zkController != null) {
-
-      // Are we hosting the shard that this request is for, and are we active? If so, then handle it ourselves
-      // and make it a non-distributed request.
-      String ourSlice = cloudDescriptor.getShardId();
-      String ourCollection = cloudDescriptor.getCollectionName();
-      // Some requests may only be fulfilled by replicas of type Replica.Type.NRT
-      boolean onlyNrtReplicas = Boolean.TRUE == req.getContext().get(ONLY_NRT_REPLICAS);
-      if (rb.slices.length == 1 && rb.slices[0] != null
-          && ( rb.slices[0].equals(ourSlice) || rb.slices[0].equals(ourCollection + "_" + ourSlice) )  // handle the <collection>_<slice> format
-          && cloudDescriptor.getLastPublished() == Replica.State.ACTIVE
-          && (!onlyNrtReplicas || cloudDescriptor.getReplicaType() == Replica.Type.NRT)) {
-        boolean shortCircuit = params.getBool("shortCircuit", true);       // currently just a debugging parameter to check distrib search on a single node
-
-        String targetHandler = params.get(ShardParams.SHARDS_QT);
-        shortCircuit = shortCircuit && targetHandler == null;             // if a different handler is specified, don't short-circuit
-
-        if (shortCircuit) {
-          rb.isDistrib = false;
-          rb.shortCircuitedURL = ZkCoreNodeProps.getCoreUrl(zkController.getBaseUrl(), coreDescriptor.getName());
-          if (hostChecker.isWhitelistHostCheckingEnabled() && hostChecker.hasExplicitWhitelist()) {
-            /*
-             * We only need to check the host whitelist if there is an explicit whitelist (other than all the live nodes)
-             * when the "shards" indicate cluster state elements only
-             */
-            hostChecker.checkWhitelist(clusterState, shards, Arrays.asList(rb.shortCircuitedURL));
-          }
-          return;
-        }
+      boolean onlyNrt = Boolean.TRUE == req.getContext().get(ONLY_NRT_REPLICAS);
+
+      replicaSource = new CloudReplicaSource.Builder()
+          .params(params)
+          .zkStateReader(zkController.getZkStateReader())
+          .whitelistHostChecker(hostChecker)
+          .replicaListTransformer(replicaListTransformer)
+          .collection(cloudDescriptor.getCollectionName())
+          .onlyNrt(onlyNrt)
+          .build();
+      rb.slices = replicaSource.getSliceNames().toArray(new String[replicaSource.getSliceCount()]);
+
+      if (canShortCircuit(rb.slices, onlyNrt, params, cloudDescriptor)) {
+        rb.isDistrib = false;
+        rb.shortCircuitedURL = ZkCoreNodeProps.getCoreUrl(zkController.getBaseUrl(), coreDescriptor.getName());
+        return;
         // We shouldn't need to do anything to handle "shard.rows" since it was previously meant to be an optimization?
       }
-      
-      if (clusterState == null && zkController != null) {
-        clusterState =  zkController.getClusterState();
-      }
-
 
-      for (int i=0; i<rb.shards.length; i++) {
-        if (rb.shards[i] != null) {
-          final List<String> shardUrls = StrUtils.splitSmart(rb.shards[i], "|", true);
-          replicaListTransformer.transform(shardUrls);
-          hostChecker.checkWhitelist(clusterState, shards, shardUrls);
-          // And now recreate the | delimited list of equivalent servers
-          rb.shards[i] = createSliceShardsStr(shardUrls);
-        } else {
-          if (slices == null) {
-            slices = clusterState.getCollection(cloudDescriptor.getCollectionName()).getSlicesMap();
-          }
-          String sliceName = rb.slices[i];
-
-          Slice slice = slices.get(sliceName);
-
-          if (slice==null) {
-            // Treat this the same as "all servers down" for a slice, and let things continue
-            // if partial results are acceptable
-            rb.shards[i] = "";
-            continue;
-            // throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "no such shard: " + sliceName);
-          }
-          final Predicate<Replica> isShardLeader = new Predicate<Replica>() {
-            private Replica shardLeader = null;
-
-            @Override
-            public boolean test(Replica replica) {
-              if (shardLeader == null) {
-                try {
-                  shardLeader = zkController.getZkStateReader().getLeaderRetry(cloudDescriptor.getCollectionName(), slice.getName());
-                } catch (InterruptedException e) {
-                  throw new SolrException(SolrException.ErrorCode.SERVICE_UNAVAILABLE, "Exception finding leader for shard " + slice.getName() + " in collection " 
-                      + cloudDescriptor.getCollectionName(), e);
-                } catch (SolrException e) {
-                  if (log.isDebugEnabled()) {
-                    log.debug("Exception finding leader for shard {} in collection {}. Collection State: {}", 
-                        slice.getName(), cloudDescriptor.getCollectionName(), zkController.getZkStateReader().getClusterState().getCollectionOrNull(cloudDescriptor.getCollectionName()));
-                  }
-                  throw e;
-                }
-              }
-              return replica.getName().equals(shardLeader.getName());
-            }
-          };
-
-          final List<Replica> eligibleSliceReplicas = collectEligibleReplicas(slice, clusterState, onlyNrtReplicas, isShardLeader);
-
-          final List<String> shardUrls = transformReplicasToShardUrls(replicaListTransformer, eligibleSliceReplicas);
-
-          if (hostChecker.isWhitelistHostCheckingEnabled() && hostChecker.hasExplicitWhitelist()) {
-            /*
-             * We only need to check the host whitelist if there is an explicit whitelist (other than all the live nodes)
-             * when the "shards" indicate cluster state elements only
-             */
-            hostChecker.checkWhitelist(clusterState, shards, shardUrls);
-          }
-
-          // And now recreate the | delimited list of equivalent servers
-          final String sliceShardsStr = createSliceShardsStr(shardUrls);
-          if (sliceShardsStr.isEmpty()) {
-            boolean tolerant = ShardParams.getShardsTolerantAsBool(rb.req.getParams());
-            if (!tolerant) {
-              // stop the check when there are no replicas available for a shard
-              throw new SolrException(SolrException.ErrorCode.SERVICE_UNAVAILABLE,
-                  "no servers hosting shard: " + rb.slices[i]);
-            }
-          }
-          rb.shards[i] = sliceShardsStr;
+      for (int i = 0; i < rb.slices.length; i++) {
+        if (!ShardParams.getShardsTolerantAsBool(params) && replicaSource.getReplicasBySlice(i).isEmpty()) {
+          // stop the check when there are no replicas available for a shard
+          // todo fix use of slices[i] which can be null if user specified urls in shards param
+          throw new SolrException(SolrException.ErrorCode.SERVICE_UNAVAILABLE,
+              "no servers hosting shard: " + rb.slices[i]);
         }
       }
     } else {
-      if (shards != null) {
-        // No cloud, verbatim check of shards
-        hostChecker.checkWhitelist(shards, new ArrayList<>(Arrays.asList(shards.split("[,|]"))));
-      }
+      replicaSource = new LegacyReplicaSource.Builder()
+          .whitelistHostChecker(hostChecker)
+          .shards(shards)
+          .build();
+      rb.slices = new String[replicaSource.getSliceCount()];
+    }
+
+    rb.shards = new String[rb.slices.length];
+    for (int i = 0; i < rb.slices.length; i++) {
+      rb.shards[i] = createSliceShardsStr(replicaSource.getReplicasBySlice(i));
     }
+
     String shards_rows = params.get(ShardParams.SHARDS_ROWS);
-    if(shards_rows != null) {
+    if (shards_rows != null) {
       rb.shards_rows = Integer.parseInt(shards_rows);
     }
     String shards_start = params.get(ShardParams.SHARDS_START);
-    if(shards_start != null) {
+    if (shards_start != null) {
       rb.shards_start = Integer.parseInt(shards_start);
     }
   }
 
-  private static List<Replica> collectEligibleReplicas(Slice slice, ClusterState clusterState, boolean onlyNrtReplicas, Predicate<Replica> isShardLeader) {
-    final Collection<Replica> allSliceReplicas = slice.getReplicasMap().values();
-    final List<Replica> eligibleSliceReplicas = new ArrayList<>(allSliceReplicas.size());
-    for (Replica replica : allSliceReplicas) {
-      if (!clusterState.liveNodesContain(replica.getNodeName())
-          || replica.getState() != Replica.State.ACTIVE
-          || (onlyNrtReplicas && replica.getType() == Replica.Type.PULL)) {
-        continue;
-      }
-
-      if (onlyNrtReplicas && replica.getType() == Replica.Type.TLOG) {
-        if (!isShardLeader.test(replica)) {
-          continue;
-        }
-      }
-      eligibleSliceReplicas.add(replica);
-    }
-    return eligibleSliceReplicas;
-  }
-
-  private static List<String> transformReplicasToShardUrls(final ReplicaListTransformer replicaListTransformer, final List<Replica> eligibleSliceReplicas) {
-    replicaListTransformer.transform(eligibleSliceReplicas);
-
-    final List<String> shardUrls = new ArrayList<>(eligibleSliceReplicas.size());
-    for (Replica replica : eligibleSliceReplicas) {
-      String url = ZkCoreNodeProps.getCoreUrl(replica);
-      shardUrls.add(url);
-    }
-    return shardUrls;
-  }
-
   private static String createSliceShardsStr(final List<String> shardUrls) {
     final StringBuilder sliceShardsStr = new StringBuilder();
     boolean first = true;
@@ -554,17 +362,28 @@ public class HttpShardHandler extends ShardHandler {
     return sliceShardsStr.toString();
   }
 
-
-  private void addSlices(Map<String,Slice> target, ClusterState state, SolrParams params, String collectionName, String shardKeys, boolean multiCollection) {
-    DocCollection coll = state.getCollection(collectionName);
-    Collection<Slice> slices = coll.getRouter().getSearchSlices(shardKeys, params , coll);
-    ClientUtils.addSlices(target, collectionName, slices, multiCollection);
+  private boolean canShortCircuit(String[] slices, boolean onlyNrtReplicas, SolrParams params, CloudDescriptor cloudDescriptor) {
+    // Are we hosting the shard that this request is for, and are we active? If so, then handle it ourselves
+    // and make it a non-distributed request.
+    String ourSlice = cloudDescriptor.getShardId();
+    String ourCollection = cloudDescriptor.getCollectionName();
+    // Some requests may only be fulfilled by replicas of type Replica.Type.NRT
+    if (slices.length == 1 && slices[0] != null
+        && (slices[0].equals(ourSlice) || slices[0].equals(ourCollection + "_" + ourSlice))  // handle the <collection>_<slice> format
+        && cloudDescriptor.getLastPublished() == Replica.State.ACTIVE
+        && (!onlyNrtReplicas || cloudDescriptor.getReplicaType() == Replica.Type.NRT)) {
+      boolean shortCircuit = params.getBool("shortCircuit", true);       // currently just a debugging parameter to check distrib search on a single node
+
+      String targetHandler = params.get(ShardParams.SHARDS_QT);
+      shortCircuit = shortCircuit && targetHandler == null;             // if a different handler is specified, don't short-circuit
+
+      return shortCircuit;
+    }
+    return false;
   }
 
-  public ShardHandlerFactory getShardHandlerFactory(){
+  public ShardHandlerFactory getShardHandlerFactory() {
     return httpShardHandlerFactory;
   }
 
-
-
 }
diff --git a/solr/core/src/java/org/apache/solr/handler/component/HttpShardHandlerFactory.java b/solr/core/src/java/org/apache/solr/handler/component/HttpShardHandlerFactory.java
index c6c6d4f..96b9b94 100644
--- a/solr/core/src/java/org/apache/solr/handler/component/HttpShardHandlerFactory.java
+++ b/solr/core/src/java/org/apache/solr/handler/component/HttpShardHandlerFactory.java
@@ -419,8 +419,8 @@ public class HttpShardHandlerFactory extends ShardHandlerFactory implements org.
   /**
    * Creates a new completion service for use by a single set of distributed requests.
    */
-  public CompletionService newCompletionService() {
-    return new ExecutorCompletionService<ShardResponse>(commExecutor);
+  public CompletionService<ShardResponse> newCompletionService() {
+    return new ExecutorCompletionService<>(commExecutor);
   }
 
   /**
diff --git a/solr/core/src/java/org/apache/solr/handler/component/LegacyReplicaSource.java b/solr/core/src/java/org/apache/solr/handler/component/LegacyReplicaSource.java
new file mode 100644
index 0000000..a67d3a5
--- /dev/null
+++ b/solr/core/src/java/org/apache/solr/handler/component/LegacyReplicaSource.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.handler.component;
+
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.solr.common.util.StrUtils;
+
+/**
+ * A replica source for Solr standalone (non-cloud) mode
+ */
+class LegacyReplicaSource implements ReplicaSource {
+  private List<String>[] replicas;
+
+  public LegacyReplicaSource(Builder builder) {
+    List<String> list = StrUtils.splitSmart(builder.shardsParam, ",", true);
+    replicas = new List[list.size()];
+    for (int i = 0; i < list.size(); i++) {
+      replicas[i] = StrUtils.splitSmart(list.get(i), "|", true);
+      // TODO: do we really not need to apply the replica list transformer in non-cloud mode?
+      // builder.replicaListTransformer.transform(replicas[i]);
+      builder.hostChecker.checkWhitelist(builder.shardsParam, replicas[i]);
+    }
+  }
+
+  @Override
+  public List<String> getSliceNames() {
+    // there are no logical slice names in non-cloud mode
+    return Collections.emptyList();
+  }
+
+  @Override
+  public int getSliceCount() {
+    return replicas.length;
+  }
+
+  @Override
+  public List<String> getReplicasBySlice(int sliceNumber) {
+    assert sliceNumber >= 0 && sliceNumber < replicas.length;
+    return replicas[sliceNumber];
+  }
+
+  static class Builder {
+    private String shardsParam;
+    private HttpShardHandlerFactory.WhitelistHostChecker hostChecker;
+
+    public Builder shards(String shardsParam) {
+      this.shardsParam = shardsParam;
+      return this;
+    }
+
+    public Builder whitelistHostChecker(HttpShardHandlerFactory.WhitelistHostChecker hostChecker) {
+      this.hostChecker = hostChecker;
+      return this;
+    }
+
+    public LegacyReplicaSource build() {
+      return new LegacyReplicaSource(this);
+    }
+  }
+}
diff --git a/solr/test-framework/src/java/org/apache/solr/cloud/MockZkStateReader.java b/solr/core/src/java/org/apache/solr/handler/component/ReplicaSource.java
similarity index 59%
copy from solr/test-framework/src/java/org/apache/solr/cloud/MockZkStateReader.java
copy to solr/core/src/java/org/apache/solr/handler/component/ReplicaSource.java
index b0ba518..979e69b 100644
--- a/solr/test-framework/src/java/org/apache/solr/cloud/MockZkStateReader.java
+++ b/solr/core/src/java/org/apache/solr/handler/component/ReplicaSource.java
@@ -14,26 +14,29 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.solr.cloud;
 
-import java.util.Set;
+package org.apache.solr.handler.component;
 
-import org.apache.solr.common.cloud.ClusterState;
-import org.apache.solr.common.cloud.ZkStateReader;
+import java.util.List;
 
-// does not yet mock zkclient at all
-public class MockZkStateReader extends ZkStateReader {
-
-  private Set<String> collections;
+/**
+ * A source of slices and corresponding replicas required to execute a request.
+ *
+ * @lucene.experimental
+ */
+interface ReplicaSource {
+  /**
+   * @return the list of slice names
+   */
+  List<String> getSliceNames();
 
-  public MockZkStateReader(ClusterState clusterState, Set<String> collections) {
-    super(new MockSolrZkClient());
-    this.clusterState = clusterState;
-    this.collections = collections;
-  }
-  
-  public Set<String> getAllCollections(){
-    return collections;
-  }
+  /**
+   * Get the list of replica URLs for a 0-indexed slice number.
+   */
+  List<String> getReplicasBySlice(int sliceNumber);
 
+  /**
+   * @return the count of slices
+   */
+  int getSliceCount();
 }
diff --git a/solr/core/src/test/org/apache/solr/cloud/ClusterStateMockUtil.java b/solr/core/src/test/org/apache/solr/cloud/ClusterStateMockUtil.java
index 90df39a..f41d80a 100644
--- a/solr/core/src/test/org/apache/solr/cloud/ClusterStateMockUtil.java
+++ b/solr/core/src/test/org/apache/solr/cloud/ClusterStateMockUtil.java
@@ -158,9 +158,26 @@ public class ClusterStateMockUtil {
           Map<String, Object> replicaPropMap = makeReplicaProps(sliceName, node, replicaName, stateCode, m.group(1));
           if (collName == null) collName = "collection" + (collectionStates.size() + 1);
           if (sliceName == null) collName = "slice" + (slices.size() + 1);
-          replica = new Replica(replicaName, replicaPropMap, collName, sliceName);
 
+          // O(n^2) alert! but this is for mocks and testing so shouldn't be used for very large cluster states
+          boolean leaderFound = false;
+          for (Map.Entry<String, Replica> entry : replicas.entrySet()) {
+            Replica value = entry.getValue();
+            if ("true".equals(value.get(Slice.LEADER)))  {
+              leaderFound = true;
+              break;
+            }
+          }
+          if (!leaderFound && !m.group(1).equals("p")) {
+            replicaPropMap.put(Slice.LEADER, "true");
+          }
+          replica = new Replica(replicaName, replicaPropMap, collName, sliceName);
           replicas.put(replica.getName(), replica);
+
+          // hack alert: re-create slice with existing data and new replicas map so that it updates its internal leader attribute
+          slice = new Slice(slice.getName(), replicas, null, collName);
+          slices.put(slice.getName(), slice);
+          // we don't need to update doc collection again because we aren't adding a new slice or changing its state
           break;
         default:
           break;
diff --git a/solr/core/src/test/org/apache/solr/handler/component/CloudReplicaSourceTest.java b/solr/core/src/test/org/apache/solr/handler/component/CloudReplicaSourceTest.java
new file mode 100644
index 0000000..186af33
--- /dev/null
+++ b/solr/core/src/test/org/apache/solr/handler/component/CloudReplicaSourceTest.java
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.handler.component;
+
+import java.util.List;
+
+import org.apache.solr.SolrTestCaseJ4;
+import org.apache.solr.client.solrj.routing.ReplicaListTransformer;
+import org.apache.solr.cloud.ClusterStateMockUtil;
+import org.apache.solr.common.cloud.ZkStateReader;
+import org.apache.solr.common.params.ModifiableSolrParams;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+/**
+ * Tests for {@link CloudReplicaSource}
+ */
+public class CloudReplicaSourceTest extends SolrTestCaseJ4 {
+
+  @BeforeClass
+  public static void setup() {
+    assumeWorkingMockito();
+  }
+
+  @Test
+  public void testSimple_ShardsParam() {
+    ReplicaListTransformer replicaListTransformer = Mockito.mock(ReplicaListTransformer.class);
+    HttpShardHandlerFactory.WhitelistHostChecker whitelistHostChecker = Mockito.mock(HttpShardHandlerFactory.WhitelistHostChecker.class);
+    ModifiableSolrParams params = new ModifiableSolrParams();
+    params.set("shards", "slice1,slice2");
+    try (ZkStateReader zkStateReader = ClusterStateMockUtil.buildClusterState("csr*sr2", "baseUrl1_", "baseUrl2_")) {
+      CloudReplicaSource cloudReplicaSource = new CloudReplicaSource.Builder()
+          .collection("collection1")
+          .onlyNrt(false)
+          .zkStateReader(zkStateReader)
+          .replicaListTransformer(replicaListTransformer)
+          .whitelistHostChecker(whitelistHostChecker)
+          .params(params)
+          .build();
+      assertEquals(2, cloudReplicaSource.getSliceCount());
+      assertEquals(2, cloudReplicaSource.getSliceNames().size());
+      assertEquals(1, cloudReplicaSource.getReplicasBySlice(0).size());
+      assertEquals("http://baseUrl1/slice1_replica1/", cloudReplicaSource.getReplicasBySlice(0).get(0));
+      assertEquals(1, cloudReplicaSource.getReplicasBySlice(1).size());
+      assertEquals("http://baseUrl2/slice2_replica2/", cloudReplicaSource.getReplicasBySlice(1).get(0));
+    }
+  }
+
+  @Test
+  public void testShardsParam_DeadNode() {
+    ReplicaListTransformer replicaListTransformer = Mockito.mock(ReplicaListTransformer.class);
+    HttpShardHandlerFactory.WhitelistHostChecker whitelistHostChecker = Mockito.mock(HttpShardHandlerFactory.WhitelistHostChecker.class);
+    ModifiableSolrParams params = new ModifiableSolrParams();
+    params.set("shards", "slice1,slice2");
+    // here node2 is not live so there should be no replicas found for slice2
+    try (ZkStateReader zkStateReader = ClusterStateMockUtil.buildClusterState("csr*sr2", "baseUrl1_")) {
+      CloudReplicaSource cloudReplicaSource = new CloudReplicaSource.Builder()
+          .collection("collection1")
+          .onlyNrt(false)
+          .zkStateReader(zkStateReader)
+          .replicaListTransformer(replicaListTransformer)
+          .whitelistHostChecker(whitelistHostChecker)
+          .params(params)
+          .build();
+      assertEquals(2, cloudReplicaSource.getSliceCount());
+      assertEquals(2, cloudReplicaSource.getSliceNames().size());
+      assertEquals(1, cloudReplicaSource.getReplicasBySlice(0).size());
+      assertEquals("http://baseUrl1/slice1_replica1/", cloudReplicaSource.getReplicasBySlice(0).get(0));
+      assertEquals(0, cloudReplicaSource.getReplicasBySlice(1).size());
+    }
+  }
+
+  @Test
+  public void testShardsParam_DownReplica() {
+    ReplicaListTransformer replicaListTransformer = Mockito.mock(ReplicaListTransformer.class);
+    HttpShardHandlerFactory.WhitelistHostChecker whitelistHostChecker = Mockito.mock(HttpShardHandlerFactory.WhitelistHostChecker.class);
+    ModifiableSolrParams params = new ModifiableSolrParams();
+    params.set("shards", "slice1,slice2");
+    // here replica3 is in DOWN state so only 1 replica should be returned for slice2
+    try (ZkStateReader zkStateReader = ClusterStateMockUtil.buildClusterState("csr*sr2r3D", "baseUrl1_", "baseUrl2_", "baseUrl3_")) {
+      CloudReplicaSource cloudReplicaSource = new CloudReplicaSource.Builder()
+          .collection("collection1")
+          .onlyNrt(false)
+          .zkStateReader(zkStateReader)
+          .replicaListTransformer(replicaListTransformer)
+          .whitelistHostChecker(whitelistHostChecker)
+          .params(params)
+          .build();
+      assertEquals(2, cloudReplicaSource.getSliceCount());
+      assertEquals(2, cloudReplicaSource.getSliceNames().size());
+      assertEquals(1, cloudReplicaSource.getReplicasBySlice(0).size());
+      assertEquals("http://baseUrl1/slice1_replica1/", cloudReplicaSource.getReplicasBySlice(0).get(0));
+      assertEquals(1, cloudReplicaSource.getReplicasBySlice(1).size());
+      assertEquals(1, cloudReplicaSource.getReplicasBySlice(1).size());
+      assertEquals("http://baseUrl2/slice2_replica2/", cloudReplicaSource.getReplicasBySlice(1).get(0));
+    }
+  }
+
+  @Test
+  public void testMultipleCollections() {
+    ReplicaListTransformer replicaListTransformer = Mockito.mock(ReplicaListTransformer.class);
+    HttpShardHandlerFactory.WhitelistHostChecker whitelistHostChecker = Mockito.mock(HttpShardHandlerFactory.WhitelistHostChecker.class);
+    ModifiableSolrParams params = new ModifiableSolrParams();
+    params.set("collection", "collection1,collection2");
+    try (ZkStateReader zkStateReader = ClusterStateMockUtil.buildClusterState("csr*sr2csr*", "baseUrl1_", "baseUrl2_")) {
+      CloudReplicaSource cloudReplicaSource = new CloudReplicaSource.Builder()
+          .collection("collection1")
+          .onlyNrt(false)
+          .zkStateReader(zkStateReader)
+          .replicaListTransformer(replicaListTransformer)
+          .whitelistHostChecker(whitelistHostChecker)
+          .params(params)
+          .build();
+      assertEquals(3, cloudReplicaSource.getSliceCount());
+      List<String> sliceNames = cloudReplicaSource.getSliceNames();
+      assertEquals(3, sliceNames.size());
+      for (int i = 0; i < cloudReplicaSource.getSliceCount(); i++) {
+        String sliceName = sliceNames.get(i);
+        assertEquals(1, cloudReplicaSource.getReplicasBySlice(i).size());
+
+        // need a switch here because, unlike the testShards* tests which always return slices in the order they were specified,
+        // using the collection param can return slice names in any order
+        switch (sliceName) {
+          case "collection1_slice1":
+            assertEquals("http://baseUrl1/slice1_replica1/", cloudReplicaSource.getReplicasBySlice(i).get(0));
+            break;
+          case "collection1_slice2":
+            assertEquals("http://baseUrl2/slice2_replica2/", cloudReplicaSource.getReplicasBySlice(i).get(0));
+            break;
+          case "collection2_slice1":
+            assertEquals("http://baseUrl1/slice1_replica3/", cloudReplicaSource.getReplicasBySlice(i).get(0));
+            break;
+        }
+      }
+    }
+  }
+
+  @Test
+  public void testSimple_UsingClusterState() {
+    ReplicaListTransformer replicaListTransformer = Mockito.mock(ReplicaListTransformer.class);
+    HttpShardHandlerFactory.WhitelistHostChecker whitelistHostChecker = Mockito.mock(HttpShardHandlerFactory.WhitelistHostChecker.class);
+    ModifiableSolrParams params = new ModifiableSolrParams();
+    try (ZkStateReader zkStateReader = ClusterStateMockUtil.buildClusterState("csr*sr2", "baseUrl1_", "baseUrl2_")) {
+      CloudReplicaSource cloudReplicaSource = new CloudReplicaSource.Builder()
+          .collection("collection1")
+          .onlyNrt(false)
+          .zkStateReader(zkStateReader)
+          .replicaListTransformer(replicaListTransformer)
+          .whitelistHostChecker(whitelistHostChecker)
+          .params(params)
+          .build();
+      assertEquals(2, cloudReplicaSource.getSliceCount());
+      List<String> sliceNames = cloudReplicaSource.getSliceNames();
+      assertEquals(2, sliceNames.size());
+      for (int i = 0; i < cloudReplicaSource.getSliceCount(); i++) {
+        String sliceName = sliceNames.get(i);
+        assertEquals(1, cloudReplicaSource.getReplicasBySlice(i).size());
+
+        // need to switch because without a shards param, the order of slices is not deterministic
+        switch (sliceName) {
+          case "slice1":
+            assertEquals("http://baseUrl1/slice1_replica1/", cloudReplicaSource.getReplicasBySlice(i).get(0));
+            break;
+          case "slice2":
+            assertEquals("http://baseUrl2/slice2_replica2/", cloudReplicaSource.getReplicasBySlice(i).get(0));
+            break;
+        }
+      }
+    }
+  }
+
+  @Test
+  public void testSimple_OnlyNrt() {
+    ReplicaListTransformer replicaListTransformer = Mockito.mock(ReplicaListTransformer.class);
+    HttpShardHandlerFactory.WhitelistHostChecker whitelistHostChecker = Mockito.mock(HttpShardHandlerFactory.WhitelistHostChecker.class);
+    ModifiableSolrParams params = new ModifiableSolrParams();
+    // the cluster state will have slice2 with two tlog replicas out of which the first one will be the leader
+    try (ZkStateReader zkStateReader = ClusterStateMockUtil.buildClusterState("csrr*st2t2", "baseUrl1_", "baseUrl2_")) {
+      CloudReplicaSource cloudReplicaSource = new CloudReplicaSource.Builder()
+          .collection("collection1")
+          .onlyNrt(true) // enable only nrt mode
+          .zkStateReader(zkStateReader)
+          .replicaListTransformer(replicaListTransformer)
+          .whitelistHostChecker(whitelistHostChecker)
+          .params(params)
+          .build();
+      assertEquals(2, cloudReplicaSource.getSliceCount());
+      List<String> sliceNames = cloudReplicaSource.getSliceNames();
+      assertEquals(2, sliceNames.size());
+      for (int i = 0; i < cloudReplicaSource.getSliceCount(); i++) {
+        String sliceName = sliceNames.get(i);
+        // need to switch because without a shards param, the order of slices is not deterministic
+        switch (sliceName) {
+          case "slice1":
+            assertEquals(2, cloudReplicaSource.getReplicasBySlice(i).size());
+            assertEquals("http://baseUrl1/slice1_replica1/", cloudReplicaSource.getReplicasBySlice(i).get(0));
+            break;
+          case "slice2":
+            assertEquals(1, cloudReplicaSource.getReplicasBySlice(i).size());
+            assertEquals("http://baseUrl2/slice2_replica3/", cloudReplicaSource.getReplicasBySlice(i).get(0));
+            break;
+        }
+      }
+    }
+  }
+
+  @Test
+  public void testMultipleCollections_OnlyNrt() {
+    ReplicaListTransformer replicaListTransformer = Mockito.mock(ReplicaListTransformer.class);
+    HttpShardHandlerFactory.WhitelistHostChecker whitelistHostChecker = Mockito.mock(HttpShardHandlerFactory.WhitelistHostChecker.class);
+    ModifiableSolrParams params = new ModifiableSolrParams();
+    params.set("collection", "collection1,collection2");
+    // the cluster state will have collection1 with slice2 with two tlog replicas out of which the first one will be the leader
+    // and collection2 with just a single slice and a tlog replica that will be leader
+    try (ZkStateReader zkStateReader = ClusterStateMockUtil.buildClusterState("csrr*st2t2cst", "baseUrl1_", "baseUrl2_")) {
+      CloudReplicaSource cloudReplicaSource = new CloudReplicaSource.Builder()
+          .collection("collection1")
+          .onlyNrt(true) // enable only nrt mode
+          .zkStateReader(zkStateReader)
+          .replicaListTransformer(replicaListTransformer)
+          .whitelistHostChecker(whitelistHostChecker)
+          .params(params)
+          .build();
+      assertEquals(3, cloudReplicaSource.getSliceCount());
+      List<String> sliceNames = cloudReplicaSource.getSliceNames();
+      assertEquals(3, sliceNames.size());
+      for (int i = 0; i < cloudReplicaSource.getSliceCount(); i++) {
+        String sliceName = sliceNames.get(i);
+        // need to switch because without a shards param, the order of slices is not deterministic
+        switch (sliceName) {
+          case "collection1_slice1":
+            assertEquals(2, cloudReplicaSource.getReplicasBySlice(i).size());
+            assertEquals("http://baseUrl1/slice1_replica1/", cloudReplicaSource.getReplicasBySlice(i).get(0));
+            break;
+          case "collection1_slice2":
+            assertEquals(1, cloudReplicaSource.getReplicasBySlice(i).size());
+            assertEquals("http://baseUrl2/slice2_replica3/", cloudReplicaSource.getReplicasBySlice(i).get(0));
+            break;
+          case "collection2_slice1":
+            assertEquals(1, cloudReplicaSource.getReplicasBySlice(i).size());
+            assertEquals("http://baseUrl1/slice1_replica5/", cloudReplicaSource.getReplicasBySlice(i).get(0));
+            break;
+        }
+      }
+    }
+  }
+}
diff --git a/solr/test-framework/src/java/org/apache/solr/cloud/MockZkStateReader.java b/solr/test-framework/src/java/org/apache/solr/cloud/MockZkStateReader.java
index b0ba518..2f73792 100644
--- a/solr/test-framework/src/java/org/apache/solr/cloud/MockZkStateReader.java
+++ b/solr/test-framework/src/java/org/apache/solr/cloud/MockZkStateReader.java
@@ -19,6 +19,7 @@ package org.apache.solr.cloud;
 import java.util.Set;
 
 import org.apache.solr.common.cloud.ClusterState;
+import org.apache.solr.common.cloud.DocCollectionWatcher;
 import org.apache.solr.common.cloud.ZkStateReader;
 
 // does not yet mock zkclient at all
@@ -36,4 +37,10 @@ public class MockZkStateReader extends ZkStateReader {
     return collections;
   }
 
+  @Override
+  public void registerDocCollectionWatcher(String collection, DocCollectionWatcher stateWatcher) {
+    // the doc collection will never be changed by this mock
+    // so we just call onStateChanged once with the existing DocCollection object and return
+    stateWatcher.onStateChanged(clusterState.getCollectionOrNull(collection));
+  }
 }