You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@solr.apache.org by no...@apache.org on 2022/09/20 05:53:40 UTC

[solr] branch main updated: SOLR-16414: Race condition in PRS state updates (#1019)

This is an automated email from the ASF dual-hosted git repository.

noble pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/solr.git


The following commit(s) were added to refs/heads/main by this push:
     new f2a9aa5f609 SOLR-16414: Race condition in PRS state updates (#1019)
f2a9aa5f609 is described below

commit f2a9aa5f609916b865e274f7f1e25dbce3d943a5
Author: Noble Paul <no...@users.noreply.github.com>
AuthorDate: Tue Sep 20 15:53:34 2022 +1000

    SOLR-16414: Race condition in PRS state updates (#1019)
---
 solr/CHANGES.txt                                   |  2 +
 .../solr/cloud/ShardLeaderElectionContextBase.java |  3 +-
 .../java/org/apache/solr/cloud/ZkController.java   | 48 +++++++++++----
 .../cloud/api/collections/CreateCollectionCmd.java |  3 +-
 .../solr/cloud/overseer/CollectionMutator.java     | 45 ++++++++++++--
 .../apache/solr/cloud/overseer/NodeMutator.java    |  1 +
 .../apache/solr/cloud/overseer/ReplicaMutator.java | 43 ++++++--------
 .../apache/solr/cloud/overseer/SliceMutator.java   | 30 ++--------
 .../solr/common/cloud/PerReplicaStatesOps.java     | 68 ++++++++++++----------
 solr/solrj/build.gradle                            |  1 +
 .../cloud/PerReplicaStatesIntegrationTest.java     | 53 +++++++++++++++++
 11 files changed, 198 insertions(+), 99 deletions(-)

diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt
index ad44c1ad9b0..3467a8f85c0 100644
--- a/solr/CHANGES.txt
+++ b/solr/CHANGES.txt
@@ -155,6 +155,8 @@ Bug Fixes
 
 * SOLR-16417: NPE if facet query hits timeout or exception (Kevin Risden)
 
+* SOLR-16414: Race condition in PRS state updates (noble, Justin Sweeney, Patson Luk, Hitesh Khamesra, Ishan Chattopadhyaya)
+
 Other Changes
 ---------------------
 * SOLR-16351: Upgrade Carrot2 to 4.4.3, upgrade randomizedtesting to 2.8.0. (Dawid Weiss)
diff --git a/solr/core/src/java/org/apache/solr/cloud/ShardLeaderElectionContextBase.java b/solr/core/src/java/org/apache/solr/cloud/ShardLeaderElectionContextBase.java
index b2ae127eb2a..3c37c34f042 100644
--- a/solr/core/src/java/org/apache/solr/cloud/ShardLeaderElectionContextBase.java
+++ b/solr/core/src/java/org/apache/solr/cloud/ShardLeaderElectionContextBase.java
@@ -237,7 +237,8 @@ class ShardLeaderElectionContextBase extends ElectionContext {
         } else {
           zkController.getOverseer().offerStateUpdate(Utils.toJSON(m));
         }
-      } else {
+      }
+      if (coll != null && coll.isPerReplicaState()) {
         PerReplicaStates prs =
             PerReplicaStatesFetcher.fetch(coll.getZNode(), zkClient, coll.getPerReplicaStates());
         PerReplicaStatesOps.flipLeader(
diff --git a/solr/core/src/java/org/apache/solr/cloud/ZkController.java b/solr/core/src/java/org/apache/solr/cloud/ZkController.java
index 1307abfada8..267fc55cc0b 100644
--- a/solr/core/src/java/org/apache/solr/cloud/ZkController.java
+++ b/solr/core/src/java/org/apache/solr/cloud/ZkController.java
@@ -1781,10 +1781,10 @@ public class ZkController implements Closeable {
         } else {
           overseerJobQueue.offer(Utils.toJSON(m));
         }
-      } else {
-        if (log.isDebugEnabled()) {
-          log.debug("bypassed overseer for message : {}", Utils.toJSONString(m));
-        }
+      }
+      // extra handling for PRS, we need to write the PRS entries from this node directly,
+      // as overseer does not and should not handle those entries
+      if (coll != null && coll.isPerReplicaState() && coreNodeName != null) {
         PerReplicaStates perReplicaStates =
             PerReplicaStatesFetcher.fetch(coll.getZNode(), zkClient, coll.getPerReplicaStates());
         PerReplicaStatesOps.flipState(coreNodeName, state, perReplicaStates)
@@ -1806,7 +1806,6 @@ public class ZkController implements Closeable {
     if (r == null) return true;
     Slice shard = coll.getSlice(r.shard);
     if (shard == null) return true; // very unlikely
-    if (shard.getState() == Slice.State.RECOVERY) return true;
     if (shard.getParent() != null) return true;
     for (Slice slice : coll.getSlices()) {
       if (Objects.equals(shard.getName(), slice.getParent())) return true;
@@ -2921,13 +2920,40 @@ public class ZkController implements Closeable {
       // immediately return.
       distributedClusterStateUpdater.executeNodeDownStateUpdate(nodeName, zkStateReader);
     } else {
-      ZkNodeProps m =
-          new ZkNodeProps(
-              Overseer.QUEUE_OPERATION,
-              OverseerAction.DOWNNODE.toLower(),
-              ZkStateReader.NODE_NAME_PROP,
-              nodeName);
       try {
+        // Create a concurrently accessible set to avoid repeating collections
+        Set<String> processedCollections = ConcurrentHashMap.newKeySet();
+        cc.getCoreDescriptors().parallelStream()
+            .forEach(
+                cd -> {
+                  DocCollection coll = zkStateReader.getCollection(cd.getCollectionName());
+                  if (processedCollections.add(coll.getName()) && coll.isPerReplicaState()) {
+                    final List<String> replicasToDown = new ArrayList<>();
+                    coll.forEachReplica(
+                        (s, replica) -> {
+                          if (replica.getNodeName().equals(nodeName)) {
+                            replicasToDown.add(replica.getName());
+                          }
+                        });
+                    try {
+                      PerReplicaStatesOps.downReplicas(
+                              replicasToDown,
+                              PerReplicaStatesFetcher.fetch(
+                                  coll.getZNode(), zkClient, coll.getPerReplicaStates()))
+                          .persist(coll.getZNode(), zkClient);
+                    } catch (KeeperException | InterruptedException e) {
+                      throw new RuntimeException(e);
+                    }
+                  }
+                });
+        // We always send a down node event to overseer to be safe, but overseer will not need to do
+        // anything for PRS collections
+        ZkNodeProps m =
+            new ZkNodeProps(
+                Overseer.QUEUE_OPERATION,
+                OverseerAction.DOWNNODE.toLower(),
+                ZkStateReader.NODE_NAME_PROP,
+                nodeName);
         overseer.getStateUpdateQueue().offer(Utils.toJSON(m));
       } catch (AlreadyClosedException e) {
         log.info(
diff --git a/solr/core/src/java/org/apache/solr/cloud/api/collections/CreateCollectionCmd.java b/solr/core/src/java/org/apache/solr/cloud/api/collections/CreateCollectionCmd.java
index 6cbcd910a57..f18713ea427 100644
--- a/solr/core/src/java/org/apache/solr/cloud/api/collections/CreateCollectionCmd.java
+++ b/solr/core/src/java/org/apache/solr/cloud/api/collections/CreateCollectionCmd.java
@@ -184,6 +184,8 @@ public class CreateCollectionCmd implements CollApiCmds.CollectionApiCommand {
         // This code directly updates Zookeeper by creating the collection state.json. It is
         // compatible with both distributed cluster state updates and Overseer based cluster state
         // updates.
+
+        // TODO: Consider doing this for all collections, not just the PRS collections.
         ZkWriteCommand command =
             new ClusterStateMutator(ccc.getSolrCloudManager())
                 .createCollection(clusterState, message);
@@ -337,7 +339,6 @@ public class CreateCollectionCmd implements CollApiCmds.CollectionApiCommand {
           ZkWriteCommand command =
               new SliceMutator(ccc.getSolrCloudManager()).addReplica(clusterState, props);
           byte[] data = Utils.toJSON(Collections.singletonMap(collectionName, command.collection));
-          //        log.info("collection updated : {}", new String(data, StandardCharsets.UTF_8));
           zkStateReader.getZkClient().setData(collectionPath, data, true);
           clusterState = clusterState.copyWith(collectionName, command.collection);
           newColl = command.collection;
diff --git a/solr/core/src/java/org/apache/solr/cloud/overseer/CollectionMutator.java b/solr/core/src/java/org/apache/solr/cloud/overseer/CollectionMutator.java
index 1a6e1c9f583..2eb666661b4 100644
--- a/solr/core/src/java/org/apache/solr/cloud/overseer/CollectionMutator.java
+++ b/solr/core/src/java/org/apache/solr/cloud/overseer/CollectionMutator.java
@@ -33,6 +33,7 @@ import org.apache.solr.client.solrj.request.CollectionAdminRequest;
 import org.apache.solr.common.cloud.ClusterState;
 import org.apache.solr.common.cloud.DocCollection;
 import org.apache.solr.common.cloud.DocCollection.CollectionStateProps;
+import org.apache.solr.common.cloud.PerReplicaStates;
 import org.apache.solr.common.cloud.PerReplicaStatesFetcher;
 import org.apache.solr.common.cloud.PerReplicaStatesOps;
 import org.apache.solr.common.cloud.Replica;
@@ -127,9 +128,12 @@ public class CollectionMutator {
           log.error("trying to set perReplicaState to {} from {}", val, coll.isPerReplicaState());
           continue;
         }
+        PerReplicaStates prs = PerReplicaStatesFetcher.fetch(coll.getZNode(), zkClient, null);
         replicaOps =
-            PerReplicaStatesOps.modifyCollection(
-                coll, enable, PerReplicaStatesFetcher.fetch(coll.getZNode(), zkClient, null));
+            enable ? PerReplicaStatesOps.enable(coll, prs) : PerReplicaStatesOps.disable(prs);
+        if (!enable) {
+          coll = updateReplicas(coll, prs);
+        }
       }
 
       if (message.containsKey(prop)) {
@@ -166,8 +170,6 @@ public class CollectionMutator {
       return ZkStateWriter.NO_OP;
     }
 
-    assert !props.containsKey(COLL_CONF);
-
     DocCollection collection =
         new DocCollection(
             coll.getName(), coll.getSlicesMap(), props, coll.getRouter(), coll.getZNodeVersion());
@@ -178,6 +180,41 @@ public class CollectionMutator {
     }
   }
 
+  public static DocCollection updateReplicas(DocCollection coll, PerReplicaStates prs) {
+    // we are disabling PRS. Update the replica states
+    Map<String, Slice> modifiedSlices = new LinkedHashMap<>();
+    coll.forEachReplica(
+        (s, replica) -> {
+          PerReplicaStates.State prsState = prs.states.get(replica.getName());
+          if (prsState != null) {
+            if (prsState.state != replica.getState()) {
+              Slice slice =
+                  modifiedSlices.getOrDefault(
+                      replica.getShard(), coll.getSlice(replica.getShard()));
+              replica = ReplicaMutator.setState(replica, prsState.state.toString());
+              modifiedSlices.put(replica.getShard(), slice.copyWith(replica));
+            }
+            if (prsState.isLeader != replica.isLeader()) {
+              Slice slice =
+                  modifiedSlices.getOrDefault(
+                      replica.getShard(), coll.getSlice(replica.getShard()));
+              replica =
+                  prsState.isLeader
+                      ? ReplicaMutator.setLeader(replica)
+                      : ReplicaMutator.unsetLeader(replica);
+              modifiedSlices.put(replica.getShard(), slice.copyWith(replica));
+            }
+          }
+        });
+
+    if (!modifiedSlices.isEmpty()) {
+      Map<String, Slice> slices = new LinkedHashMap<>(coll.getSlicesMap());
+      slices.putAll(modifiedSlices);
+      return coll.copyWithSlices(slices);
+    }
+    return coll;
+  }
+
   public static DocCollection updateSlice(
       String collectionName, DocCollection collection, Slice slice) {
     Map<String, Slice> slices =
diff --git a/solr/core/src/java/org/apache/solr/cloud/overseer/NodeMutator.java b/solr/core/src/java/org/apache/solr/cloud/overseer/NodeMutator.java
index acd58a5e242..4c78f0d6718 100644
--- a/solr/core/src/java/org/apache/solr/cloud/overseer/NodeMutator.java
+++ b/solr/core/src/java/org/apache/solr/cloud/overseer/NodeMutator.java
@@ -58,6 +58,7 @@ public class NodeMutator {
     for (Map.Entry<String, DocCollection> entry : collections.entrySet()) {
       String collectionName = entry.getKey();
       DocCollection docCollection = entry.getValue();
+      if (docCollection.isPerReplicaState()) continue;
 
       Optional<ZkWriteCommand> zkWriteCommand =
           computeCollectionUpdate(nodeName, collectionName, docCollection, zkClient);
diff --git a/solr/core/src/java/org/apache/solr/cloud/overseer/ReplicaMutator.java b/solr/core/src/java/org/apache/solr/cloud/overseer/ReplicaMutator.java
index def546db285..9f5fd99a2dd 100644
--- a/solr/core/src/java/org/apache/solr/cloud/overseer/ReplicaMutator.java
+++ b/solr/core/src/java/org/apache/solr/cloud/overseer/ReplicaMutator.java
@@ -29,6 +29,7 @@ import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.NoSuchElementException;
+import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
 import org.apache.commons.lang3.StringUtils;
@@ -43,9 +44,6 @@ import org.apache.solr.cloud.api.collections.SplitShardCmd;
 import org.apache.solr.common.SolrException;
 import org.apache.solr.common.cloud.ClusterState;
 import org.apache.solr.common.cloud.DocCollection;
-import org.apache.solr.common.cloud.PerReplicaStates;
-import org.apache.solr.common.cloud.PerReplicaStatesFetcher;
-import org.apache.solr.common.cloud.PerReplicaStatesOps;
 import org.apache.solr.common.cloud.Replica;
 import org.apache.solr.common.cloud.Slice;
 import org.apache.solr.common.cloud.Slice.SliceStateProps;
@@ -71,7 +69,7 @@ public class ReplicaMutator {
     this.zkClient = getZkClient(cloudManager);
   }
 
-  protected Replica setProperty(Replica replica, String key, String value) {
+  static Replica setProperty(Replica replica, String key, String value) {
     assert key != null;
     assert value != null;
 
@@ -84,7 +82,7 @@ public class ReplicaMutator {
         replica.getName(), replicaProps, replica.getCollection(), replica.getShard());
   }
 
-  protected Replica unsetProperty(Replica replica, String key) {
+  static Replica unsetProperty(Replica replica, String key) {
     assert key != null;
 
     if (!replica.containsKey(key)) return replica;
@@ -94,15 +92,15 @@ public class ReplicaMutator {
         replica.getName(), replicaProps, replica.getCollection(), replica.getShard());
   }
 
-  protected Replica setLeader(Replica replica) {
+  static Replica setLeader(Replica replica) {
     return setProperty(replica, ZkStateReader.LEADER_PROP, "true");
   }
 
-  protected Replica unsetLeader(Replica replica) {
+  static Replica unsetLeader(Replica replica) {
     return unsetProperty(replica, ZkStateReader.LEADER_PROP);
   }
 
-  protected Replica setState(Replica replica, String state) {
+  static Replica setState(Replica replica, String state) {
     assert state != null;
 
     return setProperty(replica, ZkStateReader.STATE_PROP, state);
@@ -320,8 +318,6 @@ public class ReplicaMutator {
       log.info("Failed to update state because the replica does not exist, {}", message);
       return ZkStateWriter.NO_OP;
     }
-    boolean persistCollectionState = collection != null && collection.isPerReplicaState();
-
     if (coreNodeName == null) {
       coreNodeName =
           ClusterStateMutator.getAssignedCoreNodeName(
@@ -335,7 +331,6 @@ public class ReplicaMutator {
           log.info("Failed to update state because the replica does not exist, {}", message);
           return ZkStateWriter.NO_OP;
         }
-        persistCollectionState = true;
         // if coreNodeName is null, auto assign one
         coreNodeName = Assign.assignCoreNodeName(stateManager, collection);
       }
@@ -349,7 +344,6 @@ public class ReplicaMutator {
       if (sliceName != null) {
         log.debug("shard={} is already registered", sliceName);
       }
-      persistCollectionState = true;
     }
     if (sliceName == null) {
       // request new shardId
@@ -361,14 +355,14 @@ public class ReplicaMutator {
       }
       sliceName = Assign.assignShard(collection, numShards);
       log.info("Assigning new node to shard shard={}", sliceName);
-      persistCollectionState = true;
     }
 
     Slice slice = collection != null ? collection.getSlice(sliceName) : null;
 
+    Replica oldReplica = null;
     Map<String, Object> replicaProps = new LinkedHashMap<>(message.getProperties());
     if (slice != null) {
-      Replica oldReplica = slice.getReplica(coreNodeName);
+      oldReplica = slice.getReplica(coreNodeName);
       if (oldReplica != null) {
         if (oldReplica.containsKey(ZkStateReader.LEADER_PROP)) {
           replicaProps.put(ZkStateReader.LEADER_PROP, oldReplica.get(ZkStateReader.LEADER_PROP));
@@ -444,18 +438,17 @@ public class ReplicaMutator {
 
     DocCollection newCollection = CollectionMutator.updateSlice(collectionName, collection, slice);
     log.debug("Collection is now: {}", newCollection);
-    if (collection != null && collection.isPerReplicaState()) {
-      PerReplicaStates prs =
-          PerReplicaStatesFetcher.fetch(
-              collection.getZNode(), zkClient, collection.getPerReplicaStates());
-      return new ZkWriteCommand(
-          collectionName,
-          newCollection,
-          PerReplicaStatesOps.flipState(replica.getName(), replica.getState(), prs),
-          persistCollectionState);
-    } else {
-      return new ZkWriteCommand(collectionName, newCollection);
+    if (collection.isPerReplicaState() && oldReplica != null) {
+      if (!isAnyPropertyChanged(replica, oldReplica)) return ZkWriteCommand.NO_OP;
     }
+    return new ZkWriteCommand(collectionName, newCollection);
+  }
+
+  private boolean isAnyPropertyChanged(Replica replica, Replica oldReplica) {
+    if (!Objects.equals(replica.getBaseUrl(), oldReplica.getBaseUrl())) return true;
+    if (!Objects.equals(replica.getCoreName(), oldReplica.getCoreName())) return true;
+    if (!Objects.equals(replica.getNodeName(), oldReplica.getNodeName())) return true;
+    return false;
   }
 
   private DocCollection checkAndCompleteShardSplit(
diff --git a/solr/core/src/java/org/apache/solr/cloud/overseer/SliceMutator.java b/solr/core/src/java/org/apache/solr/cloud/overseer/SliceMutator.java
index ed26494cf4e..ea06967dad7 100644
--- a/solr/core/src/java/org/apache/solr/cloud/overseer/SliceMutator.java
+++ b/solr/core/src/java/org/apache/solr/cloud/overseer/SliceMutator.java
@@ -146,16 +146,7 @@ public class SliceMutator {
       }
       newSlices.put(slice.getName(), slice);
     }
-
-    if (coll.isPerReplicaState()) {
-      PerReplicaStatesOps replicaOps =
-          PerReplicaStatesOps.deleteReplica(
-              cnn,
-              PerReplicaStatesFetcher.fetch(coll.getZNode(), zkClient, coll.getPerReplicaStates()));
-      return new ZkWriteCommand(collection, coll.copyWithSlices(newSlices), replicaOps, true);
-    } else {
-      return new ZkWriteCommand(collection, coll.copyWithSlices(newSlices));
-    }
+    return new ZkWriteCommand(collection, coll.copyWithSlices(newSlices));
   }
 
   public ZkWriteCommand setShardLeader(ClusterState clusterState, ZkNodeProps message) {
@@ -185,9 +176,9 @@ public class SliceMutator {
               replica.getBaseUrl(), replica.getStr(ZkStateReader.CORE_NAME_PROP));
 
       if (replica == oldLeader && !coreURL.equals(leaderUrl)) {
-        replica = new ReplicaMutator(cloudManager).unsetLeader(replica);
+        replica = ReplicaMutator.unsetLeader(replica);
       } else if (coreURL.equals(leaderUrl)) {
-        newLeader = replica = new ReplicaMutator(cloudManager).setLeader(replica);
+        newLeader = replica = ReplicaMutator.setLeader(replica);
       }
 
       newReplicas.put(replica.getName(), replica);
@@ -196,19 +187,8 @@ public class SliceMutator {
     Map<String, Object> newSliceProps = slice.shallowCopy();
     newSliceProps.put(SliceStateProps.REPLICAS, newReplicas);
     slice = new Slice(slice.getName(), newReplicas, slice.getProperties(), collectionName);
-    if (coll.isPerReplicaState()) {
-      PerReplicaStates prs =
-          PerReplicaStatesFetcher.fetch(coll.getZNode(), zkClient, coll.getPerReplicaStates());
-      return new ZkWriteCommand(
-          collectionName,
-          CollectionMutator.updateSlice(collectionName, coll, slice),
-          PerReplicaStatesOps.flipLeader(
-              slice.getReplicaNames(), newLeader == null ? null : newLeader.getName(), prs),
-          false);
-    } else {
-      return new ZkWriteCommand(
-          collectionName, CollectionMutator.updateSlice(collectionName, coll, slice));
-    }
+    return new ZkWriteCommand(
+        collectionName, CollectionMutator.updateSlice(collectionName, coll, slice));
   }
 
   public ZkWriteCommand updateShardState(ClusterState clusterState, ZkNodeProps message) {
diff --git a/solr/solrj-zookeeper/src/java/org/apache/solr/common/cloud/PerReplicaStatesOps.java b/solr/solrj-zookeeper/src/java/org/apache/solr/common/cloud/PerReplicaStatesOps.java
index 7c5688230d4..a21d3be9387 100644
--- a/solr/solrj-zookeeper/src/java/org/apache/solr/common/cloud/PerReplicaStatesOps.java
+++ b/solr/solrj-zookeeper/src/java/org/apache/solr/common/cloud/PerReplicaStatesOps.java
@@ -41,6 +41,7 @@ public class PerReplicaStatesOps {
   private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
   private PerReplicaStates rs;
   List<PerReplicaStates.Operation> ops;
+  private boolean preOp = true;
   final Function<PerReplicaStates, List<PerReplicaStates.Operation>> fun;
 
   PerReplicaStatesOps(Function<PerReplicaStates, List<PerReplicaStates.Operation>> fun) {
@@ -139,39 +140,42 @@ public class PerReplicaStatesOps {
         .init(rs);
   }
 
-  /** Switch a collection from/to perReplicaState=true */
-  public static PerReplicaStatesOps modifyCollection(
-      DocCollection coll, boolean enable, PerReplicaStates rs) {
-    return new PerReplicaStatesOps(prs -> enable ? enable(coll, prs) : disable(prs)).init(rs);
-  }
-
-  private static List<PerReplicaStates.Operation> enable(DocCollection coll, PerReplicaStates prs) {
-    log.info("ENABLING_PRS ");
-    List<PerReplicaStates.Operation> result = new ArrayList<>();
-    coll.forEachReplica(
-        (s, r) -> {
-          PerReplicaStates.State st = prs.get(r.getName());
-          int newVer = 0;
-          if (st != null) {
-            result.add(new PerReplicaStates.Operation(PerReplicaStates.Operation.Type.DELETE, st));
-            newVer = st.version + 1;
-          }
-          result.add(
-              new PerReplicaStates.Operation(
-                  PerReplicaStates.Operation.Type.ADD,
-                  new PerReplicaStates.State(r.getName(), r.getState(), r.isLeader(), newVer)));
-        });
-    log.info("ENABLING_PRS OPS {}", result);
-    return result;
+  /** Switch a collection to perReplicaState=true */
+  public static PerReplicaStatesOps enable(DocCollection coll, PerReplicaStates rs) {
+    return new PerReplicaStatesOps(
+            prs -> {
+              List<PerReplicaStates.Operation> result = new ArrayList<>();
+              coll.forEachReplica(
+                  (s, r) -> {
+                    PerReplicaStates.State old = prs.states.get(r.getName());
+                    int version = old == null ? 0 : old.version + 1;
+                    result.add(
+                        new PerReplicaStates.Operation(
+                            PerReplicaStates.Operation.Type.ADD,
+                            new PerReplicaStates.State(
+                                r.getName(), r.getState(), r.isLeader(), version)));
+                    addDeleteStaleNodes(result, old);
+                  });
+              return result;
+            })
+        .init(rs);
   }
 
-  private static List<PerReplicaStates.Operation> disable(PerReplicaStates prs) {
-    List<PerReplicaStates.Operation> result = new ArrayList<>();
-    prs.states.forEachEntry(
-        (s, state) ->
-            result.add(
-                new PerReplicaStates.Operation(PerReplicaStates.Operation.Type.DELETE, state)));
-    return result;
+  /** Switch a collection to perReplicaState=false */
+  public static PerReplicaStatesOps disable(PerReplicaStates rs) {
+    PerReplicaStatesOps ops =
+        new PerReplicaStatesOps(
+            prs -> {
+              List<PerReplicaStates.Operation> result = new ArrayList<>();
+              prs.states.forEachEntry(
+                  (s, state) ->
+                      result.add(
+                          new PerReplicaStates.Operation(
+                              PerReplicaStates.Operation.Type.DELETE, state)));
+              return result;
+            });
+    ops.preOp = false;
+    return ops.init(rs);
   }
 
   /**
@@ -277,7 +281,7 @@ public class PerReplicaStatesOps {
                       new PerReplicaStates.Operation(
                           PerReplicaStates.Operation.Type.ADD,
                           new PerReplicaStates.State(
-                              replica, Replica.State.DOWN, Boolean.FALSE, r.version + 1)));
+                              replica, Replica.State.DOWN, r.isLeader, r.version + 1)));
                   addDeleteStaleNodes(operations, r);
                 } else {
                   operations.add(
diff --git a/solr/solrj/build.gradle b/solr/solrj/build.gradle
index f12ef45703c..412979ad1fa 100644
--- a/solr/solrj/build.gradle
+++ b/solr/solrj/build.gradle
@@ -49,6 +49,7 @@ dependencies {
   // ideally ZK centric tests move to solrj-zookeeper but sometimes we depend on ZK here anyway
   testImplementation project(':solr:solrj-zookeeper')
   testImplementation 'org.apache.zookeeper:zookeeper'
+  testImplementation 'org.apache.zookeeper:zookeeper-jute'
   permitTestUnusedDeclared 'org.apache.zookeeper:zookeeper'
 
   testImplementation 'org.apache.lucene:lucene-core'
diff --git a/solr/solrj/src/test/org/apache/solr/common/cloud/PerReplicaStatesIntegrationTest.java b/solr/solrj/src/test/org/apache/solr/common/cloud/PerReplicaStatesIntegrationTest.java
index 8a4c36a92ee..24f6039225d 100644
--- a/solr/solrj/src/test/org/apache/solr/common/cloud/PerReplicaStatesIntegrationTest.java
+++ b/solr/solrj/src/test/org/apache/solr/common/cloud/PerReplicaStatesIntegrationTest.java
@@ -32,6 +32,7 @@ import org.apache.solr.client.solrj.response.SolrPingResponse;
 import org.apache.solr.cloud.MiniSolrCloudCluster;
 import org.apache.solr.cloud.SolrCloudTestCase;
 import org.apache.solr.util.LogLevel;
+import org.apache.zookeeper.data.Stat;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -249,4 +250,56 @@ public class PerReplicaStatesIntegrationTest extends SolrCloudTestCase {
       cluster.shutdown();
     }
   }
+
+  public void testZkNodeVersions() throws Exception {
+    String NONPRS_COLL = "non_prs_test_coll1";
+    String PRS_COLL = "prs_test_coll2";
+    MiniSolrCloudCluster cluster =
+        configureCluster(3)
+            .withDistributedClusterStateUpdates(false, false)
+            .addConfig(
+                "conf",
+                getFile("solrj")
+                    .toPath()
+                    .resolve("solr")
+                    .resolve("configsets")
+                    .resolve("streaming")
+                    .resolve("conf"))
+            .withJettyConfig(jetty -> jetty.enableV2(true))
+            .configure();
+    try {
+      Stat stat = null;
+      CollectionAdminRequest.createCollection(NONPRS_COLL, "conf", 10, 1)
+          .process(cluster.getSolrClient());
+      stat = cluster.getZkClient().exists(DocCollection.getCollectionPath(NONPRS_COLL), null, true);
+      log.info("");
+      // the actual number can vary depending on batching
+      assertTrue(stat.getVersion() >= 2);
+      assertEquals(0, stat.getCversion());
+
+      CollectionAdminRequest.createCollection(PRS_COLL, "conf", 10, 1)
+          .setPerReplicaState(Boolean.TRUE)
+          .process(cluster.getSolrClient());
+      stat = cluster.getZkClient().exists(DocCollection.getCollectionPath(PRS_COLL), null, true);
+      // 0 from CreateCollectionCmd.create() and
+      // +1 for each replica added (the setData() call in CreateCollectionCmd)
+      assertEquals(10, stat.getVersion());
+      // For each replica:
+      // +1 for ZkController#preRegister, in ZkController#publish, direct write PRS to down
+      // +2 for runLeaderProcess, flip the replica to leader
+      // +2 for ZkController#register, in ZkController#publish, direct write PRS to active
+      // Hence 5 * 10 = 50. Note that an ADD counts as +1 and every UPDATE as +2 (remove the old PRS
+      // entry and add the new PRS entry)
+      assertEquals(50, stat.getCversion());
+      for (JettySolrRunner j : cluster.getJettySolrRunners()) {
+        j.stop();
+        j.start(true);
+        stat = cluster.getZkClient().exists(DocCollection.getCollectionPath(PRS_COLL), null, true);
+        // ensure restart does not update the state.json
+        assertEquals(10, stat.getVersion());
+      }
+    } finally {
+      cluster.shutdown();
+    }
+  }
 }