You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@helix.apache.org by GitBox <gi...@apache.org> on 2021/02/05 05:40:25 UTC

[GitHub] [helix] pkuwm commented on a change in pull request #1628: Per Replica Throttle -- 2nd: messages classification and basic throttle application

pkuwm commented on a change in pull request #1628:
URL: https://github.com/apache/helix/pull/1628#discussion_r570718493



##########
File path: helix-core/src/main/java/org/apache/helix/controller/stages/PerReplicaThrottleStage.java
##########
@@ -220,6 +271,224 @@ private MessageOutput throttlePerReplicaMessages(IdealState idealState,
     return output;
   }
 
+  private void propagateCountsTopDown(StateModelDefinition stateModelDef,
+      Map<String, Integer> expectedStateCountMap) {
+    // attribute state in higher priority to lower priority
+    List<String> stateList = stateModelDef.getStatesPriorityList();
+    if (stateList.size() <= 0) {

Review comment:
       Is it possible that `StatePriorityList` is not set in the state model, so that `stateList` comes back as null?

##########
File path: helix-core/src/main/java/org/apache/helix/controller/stages/PerReplicaThrottleStage.java
##########
@@ -220,6 +271,224 @@ private MessageOutput throttlePerReplicaMessages(IdealState idealState,
     return output;
   }
 
+  private void propagateCountsTopDown(StateModelDefinition stateModelDef,
+      Map<String, Integer> expectedStateCountMap) {
+    // attribute state in higher priority to lower priority
+    List<String> stateList = stateModelDef.getStatesPriorityList();
+    if (stateList.size() <= 0) {
+      return;
+    }
+    int index = 0;
+    String prevState = stateList.get(index);
+    if (!expectedStateCountMap.containsKey(prevState)) {
+      expectedStateCountMap.put(prevState, 0);
+    }
+    while (true) {
+      if (index == stateList.size() - 1) {
+        break;
+      }
+      index++;
+      String curState = stateList.get(index);
+      String num = stateModelDef.getNumInstancesPerState(curState);
+      if ("-1".equals(num)) {
+        break;
+      }
+      Integer prevCnt = expectedStateCountMap.get(prevState);
+      expectedStateCountMap
+          .put(curState, prevCnt + expectedStateCountMap.getOrDefault(curState, 0));
+      prevState = curState;
+    }
+  }
+
+  private void getPartitionExpectedAndCurrentStateCountMap(Partition partition,
+      Map<String, List<String>> preferenceLists, IdealState idealState,
+      ResourceControllerDataProvider cache, Map<String, String> currentStateMap,
+      Map<String, Integer> expectedStateCountMapOut, Map<String, Integer> currentStateCountsOut) {

Review comment:
       A function having 2 output params just looks unclean to me :) It'd be better if we can figure out a cleaner method.

##########
File path: helix-core/src/main/java/org/apache/helix/controller/stages/PerReplicaThrottleStage.java
##########
@@ -220,6 +271,224 @@ private MessageOutput throttlePerReplicaMessages(IdealState idealState,
     return output;
   }
 
+  private void propagateCountsTopDown(StateModelDefinition stateModelDef,
+      Map<String, Integer> expectedStateCountMap) {
+    // attribute state in higher priority to lower priority
+    List<String> stateList = stateModelDef.getStatesPriorityList();
+    if (stateList.size() <= 0) {
+      return;
+    }
+    int index = 0;
+    String prevState = stateList.get(index);
+    if (!expectedStateCountMap.containsKey(prevState)) {
+      expectedStateCountMap.put(prevState, 0);
+    }
+    while (true) {
+      if (index == stateList.size() - 1) {
+        break;
+      }
+      index++;
+      String curState = stateList.get(index);
+      String num = stateModelDef.getNumInstancesPerState(curState);
+      if ("-1".equals(num)) {
+        break;
+      }
+      Integer prevCnt = expectedStateCountMap.get(prevState);

Review comment:
       `Integer` -> `int`

##########
File path: helix-core/src/main/java/org/apache/helix/controller/stages/PerReplicaThrottleStage.java
##########
@@ -220,6 +271,224 @@ private MessageOutput throttlePerReplicaMessages(IdealState idealState,
     return output;
   }
 
+  private void propagateCountsTopDown(StateModelDefinition stateModelDef,
+      Map<String, Integer> expectedStateCountMap) {
+    // attribute state in higher priority to lower priority
+    List<String> stateList = stateModelDef.getStatesPriorityList();
+    if (stateList.size() <= 0) {
+      return;
+    }
+    int index = 0;
+    String prevState = stateList.get(index);
+    if (!expectedStateCountMap.containsKey(prevState)) {
+      expectedStateCountMap.put(prevState, 0);
+    }
+    while (true) {
+      if (index == stateList.size() - 1) {
+        break;
+      }
+      index++;
+      String curState = stateList.get(index);
+      String num = stateModelDef.getNumInstancesPerState(curState);
+      if ("-1".equals(num)) {
+        break;
+      }
+      Integer prevCnt = expectedStateCountMap.get(prevState);
+      expectedStateCountMap

Review comment:
       Just a question: for index = 0, prevState is the same as curState. Is it correct?

##########
File path: helix-core/src/main/java/org/apache/helix/controller/stages/PerReplicaThrottleStage.java
##########
@@ -220,6 +271,224 @@ private MessageOutput throttlePerReplicaMessages(IdealState idealState,
     return output;
   }
 
+  private void propagateCountsTopDown(StateModelDefinition stateModelDef,
+      Map<String, Integer> expectedStateCountMap) {
+    // attribute state in higher priority to lower priority
+    List<String> stateList = stateModelDef.getStatesPriorityList();
+    if (stateList.size() <= 0) {
+      return;
+    }
+    int index = 0;
+    String prevState = stateList.get(index);
+    if (!expectedStateCountMap.containsKey(prevState)) {
+      expectedStateCountMap.put(prevState, 0);
+    }
+    while (true) {
+      if (index == stateList.size() - 1) {
+        break;
+      }
+      index++;
+      String curState = stateList.get(index);
+      String num = stateModelDef.getNumInstancesPerState(curState);
+      if ("-1".equals(num)) {
+        break;
+      }
+      Integer prevCnt = expectedStateCountMap.get(prevState);
+      expectedStateCountMap
+          .put(curState, prevCnt + expectedStateCountMap.getOrDefault(curState, 0));
+      prevState = curState;
+    }
+  }
+
+  private void getPartitionExpectedAndCurrentStateCountMap(Partition partition,
+      Map<String, List<String>> preferenceLists, IdealState idealState,
+      ResourceControllerDataProvider cache, Map<String, String> currentStateMap,
+      Map<String, Integer> expectedStateCountMapOut, Map<String, Integer> currentStateCountsOut) {
+    List<String> preferenceList = preferenceLists.get(partition.getPartitionName());
+    if (preferenceList == null) {
+      preferenceList = Collections.emptyList();
+    }

Review comment:
       `getOrDefault()`?

##########
File path: helix-core/src/main/java/org/apache/helix/controller/stages/PerReplicaThrottleStage.java
##########
@@ -220,6 +271,224 @@ private MessageOutput throttlePerReplicaMessages(IdealState idealState,
     return output;
   }
 
+  private void propagateCountsTopDown(StateModelDefinition stateModelDef,
+      Map<String, Integer> expectedStateCountMap) {
+    // attribute state in higher priority to lower priority
+    List<String> stateList = stateModelDef.getStatesPriorityList();
+    if (stateList.size() <= 0) {
+      return;
+    }
+    int index = 0;
+    String prevState = stateList.get(index);
+    if (!expectedStateCountMap.containsKey(prevState)) {
+      expectedStateCountMap.put(prevState, 0);
+    }
+    while (true) {
+      if (index == stateList.size() - 1) {
+        break;
+      }
+      index++;
+      String curState = stateList.get(index);
+      String num = stateModelDef.getNumInstancesPerState(curState);
+      if ("-1".equals(num)) {
+        break;
+      }
+      Integer prevCnt = expectedStateCountMap.get(prevState);
+      expectedStateCountMap
+          .put(curState, prevCnt + expectedStateCountMap.getOrDefault(curState, 0));
+      prevState = curState;
+    }
+  }
+
+  private void getPartitionExpectedAndCurrentStateCountMap(Partition partition,
+      Map<String, List<String>> preferenceLists, IdealState idealState,
+      ResourceControllerDataProvider cache, Map<String, String> currentStateMap,
+      Map<String, Integer> expectedStateCountMapOut, Map<String, Integer> currentStateCountsOut) {
+    List<String> preferenceList = preferenceLists.get(partition.getPartitionName());
+    if (preferenceList == null) {
+      preferenceList = Collections.emptyList();
+    }
+
+    int replica =
+        idealState.getMinActiveReplicas() == -1 ? idealState.getReplicaCount(preferenceList.size())
+            : idealState.getMinActiveReplicas();
+    Set<String> activeList = new HashSet<>(preferenceList);
+    activeList.retainAll(cache.getEnabledLiveInstances());
+
+    // For each state, check that this partition currently has the required number of that state as
+    // required by StateModelDefinition.
+    String stateModelDefName = idealState.getStateModelDefRef();
+    StateModelDefinition stateModelDef = cache.getStateModelDef(stateModelDefName);
+    LinkedHashMap<String, Integer> expectedStateCountMap =
+        stateModelDef.getStateCountMap(activeList.size(), replica); // StateModelDefinition's counts
+
+    // Current counts without disabled partitions or disabled instances
+    Map<String, String> currentStateMapWithoutDisabled = new HashMap<>(currentStateMap);
+    currentStateMapWithoutDisabled.keySet().removeAll(cache
+        .getDisabledInstancesForPartition(idealState.getResourceName(),
+            partition.getPartitionName()));
+    Map<String, Integer> currentStateCounts =
+        StateModelDefinition.getStateCounts(currentStateMapWithoutDisabled);
+
+    expectedStateCountMapOut.putAll(expectedStateCountMap);
+    currentStateCountsOut.putAll(currentStateCounts);
+    propagateCountsTopDown(stateModelDef, expectedStateCountMapOut);
+    propagateCountsTopDown(stateModelDef, currentStateCountsOut);
+  }
+
+  void calculateExistingAndCurrentStateCount(Map<Partition, List<Message>> selectedResourceMessages,
+      CurrentStateOutput currentStateOutput, BestPossibleStateOutput bestPossibleStateOutput,
+      IdealState idealState, ResourceControllerDataProvider cache,
+      Map<Partition, Map<String, Integer>> expectedStateCountByPartition,
+      Map<Partition, Map<String, Integer>> currentStateCountsByPartition) {
+    String resourceName = idealState.getResourceName();
+    Map<String, List<String>> preferenceLists =
+        bestPossibleStateOutput.getPreferenceLists(resourceName);
+    for (Partition partition : selectedResourceMessages.keySet()) {
+      Map<String, String> currentStateMap =
+          currentStateOutput.getCurrentStateMap(resourceName, partition);
+
+      Map<String, Integer> expectedStateCounts = new HashMap<>();
+      Map<String, Integer> currentStateCounts = new HashMap<>();
+      getPartitionExpectedAndCurrentStateCountMap(partition, preferenceLists, idealState, cache,
+          currentStateMap, expectedStateCounts, currentStateCounts);
+
+      // save these two maps for later usage
+      expectedStateCountByPartition.put(partition, expectedStateCounts);
+      currentStateCountsByPartition.put(partition, currentStateCounts);
+    }
+  }
+
+  /*
+   * Classify the messages of each partition into recovery and load messages.
+   */
+  private void classifyMessages(String resourceName, IdealState idealState,
+      ResourceControllerDataProvider cache, Map<Partition, List<Message>> selectedResourceMessages,
+      List<Message> recoveryMessages, List<Message> loadMessages,
+      Map<Partition, Map<String, Integer>> expectedStateCountByPartition,
+      Map<Partition, Map<String, Integer>> currentStateCountsByPartition) {
+    LogUtil.logInfo(logger, _eventId,
+        String.format("Classify message for resource: %s", resourceName));
+
+    for (Partition partition : selectedResourceMessages.keySet()) {
+      Map<String, Integer> expectedStateCountMap = expectedStateCountByPartition.get(partition);
+      Map<String, Integer> currentStateCounts = currentStateCountsByPartition.get(partition);
+
+      List<Message> partitionMessages = selectedResourceMessages.get(partition);
+      if (partitionMessages == null) {
+        continue;
+      }
+
+      String stateModelDefName = idealState.getStateModelDefRef();
+      StateModelDefinition stateModelDef = cache.getStateModelDef(stateModelDefName);
+      // sort partitionMessages based on transition priority and then creation timestamp for transition message
+      // TODO: sort messages in same partition in next PR
+      Set<String> disabledInstances =
+          cache.getDisabledInstancesForPartition(resourceName, partition.getPartitionName());
+      for (Message msg : partitionMessages) {
+        if (!Message.MessageType.STATE_TRANSITION.name().equals(msg.getMsgType())) {
+          if (logger.isDebugEnabled()) {
+            LogUtil.logDebug(logger, _eventId, String
+                .format("Message: %s not subject to throttle in resource: %s with type %s", msg,
+                    resourceName, msg.getMsgType()));
+          }
+          continue;
+        }
+
+        boolean isUpward = !isDownwardTransition(idealState, cache, msg);
+
+        // for disabled disabled instance, the downward transition is not subjected to load throttling
+        // we will let them pass through ASAP.
+        String instance = msg.getTgtName();
+        if (disabledInstances.contains(instance)) {
+          if (!isUpward) {
+            if (logger.isDebugEnabled()) {
+              LogUtil.logDebug(logger, _eventId, String.format(
+                  "Message: %s not subject to throttle in resource: %s to disabled instancce %s",
+                  msg, resourceName, instance));
+            }
+            continue;
+          }
+        }
+
+        String toState = msg.getToState();
+        if (toState.equals(HelixDefinedState.DROPPED.name()) || toState
+            .equals(HelixDefinedState.ERROR.name())) {
+          if (logger.isDebugEnabled()) {
+            LogUtil.logDebug(logger, _eventId, String
+                .format("Message: %s not subject to throttle in resource: %s with toState %s", msg,
+                    resourceName, toState));
+          }
+          continue;
+        }
+
+        Integer expectedCount = expectedStateCountMap.get(toState);
+        Integer currentCount = currentStateCounts.get(toState);
+        expectedCount = expectedCount == null ? 0 : expectedCount;
+        currentCount = currentCount == null ? 0 : currentCount;
+
+        if (isUpward && (currentCount < expectedCount)) {
+          recoveryMessages.add(msg);
+          currentStateCounts.put(toState, currentCount + 1);
+        } else {
+          loadMessages.add(msg);
+        }
+      }
+    }
+  }
+
+  protected void applyThrottling(String resourceName,

Review comment:
       8 params look like too many to me. I would simplify it if possible. There are too many logic branches inside the method, based on `onlyDownwardLoadBalance` and `rebalanceType`, which could also conflict when `isRecovery && onlyDownwardLoadBalance`. That might not be a good design. Not sure if we could split it into two methods: one for load rebalance and one for recovery rebalance, or one for upward and one for downward.
   
   At a glance, `throttledMessages` is the output; why not make it the return value?

##########
File path: helix-core/src/main/java/org/apache/helix/controller/stages/PerReplicaThrottleStage.java
##########
@@ -220,6 +271,224 @@ private MessageOutput throttlePerReplicaMessages(IdealState idealState,
     return output;
   }
 
+  private void propagateCountsTopDown(StateModelDefinition stateModelDef,
+      Map<String, Integer> expectedStateCountMap) {
+    // attribute state in higher priority to lower priority
+    List<String> stateList = stateModelDef.getStatesPriorityList();
+    if (stateList.size() <= 0) {
+      return;
+    }
+    int index = 0;
+    String prevState = stateList.get(index);
+    if (!expectedStateCountMap.containsKey(prevState)) {
+      expectedStateCountMap.put(prevState, 0);
+    }

Review comment:
       Nit: this can be simplified with the preferred method `expectedStateCountMap.putIfAbsent(prevState, 0);`

##########
File path: helix-core/src/main/java/org/apache/helix/controller/stages/PerReplicaThrottleStage.java
##########
@@ -195,11 +204,48 @@ private MessageOutput throttlePerReplicaMessages(IdealState idealState,
       return output;
     }
 
-    // TODO: later PRs
+    Set<Partition> partitionsWithErrorStateReplica = new HashSet<>();
+
+    Map<Partition, Map<String, Integer>> expectedStateCountByPartition = new HashMap<>();
+    Map<Partition, Map<String, Integer>> currentStateCountsByPartition = new HashMap<>();
+
+    calculateExistingAndCurrentStateCount(selectedResourceMessages, currentStateOutput,
+        bestPossibleStateOutput, idealState, cache, expectedStateCountByPartition,
+        currentStateCountsByPartition);
+
     // Step 1: charge existing pending messages and update retraced state map.
+    // TODO: later PRs
     // Step 2: classify all the messages into recovery message list and load message list
+    List<Message> recoveryMessages = new ArrayList<>();
+    List<Message> loadMessages = new ArrayList<>();
+    classifyMessages(resourceName, idealState, cache, selectedResourceMessages, recoveryMessages,
+        loadMessages, expectedStateCountByPartition, currentStateCountsByPartition);
+
     // Step 3: sorts recovery message list and applies throttling
+    Set<Message> throttledRecoveryMessages = new HashSet<>();
+    LogUtil.logDebug(logger, _eventId,
+        String.format("applying recovery rebalance with resource %s", resourceName));
+    applyThrottling(resourceName, throttleController, idealState, cache, false, recoveryMessages,
+        throttledRecoveryMessages, StateTransitionThrottleConfig.RebalanceType.RECOVERY_BALANCE);
+
     // Step 4: sorts load message list and applies throttling
+    // TODO: calculate error-on-recovery downward threshold with complex backward compatibility next
+    // TODO: this can be done together with chargePendingMessage() where partitionsNeedRecovery is from
+    boolean onlyDownwardLoadBalance = partitionsWithErrorStateReplica.size() > 1;
+    Set<Message> throttledLoadMessages = new HashSet<>();
+    LogUtil.logDebug(logger, _eventId, String
+        .format("applying load rebalance with resource %s, onlyDownwardLoadBalance %s",
+            resourceName, onlyDownwardLoadBalance));
+    applyThrottling(resourceName, throttleController, idealState, cache, onlyDownwardLoadBalance,
+        loadMessages, throttledLoadMessages,
+        StateTransitionThrottleConfig.RebalanceType.LOAD_BALANCE);
+
+    LogUtil.logDebug(logger, _eventId, String
+        .format("resource %s, throttled recovery message: %s", resourceName,
+            throttledRecoveryMessages));
+    LogUtil.logDebug(logger, _eventId, String
+        .format("resource %s, throttled load messages: %s", resourceName, throttledLoadMessages));

Review comment:
       `logDebug` is not guarded by an `if (logger.isDebugEnabled())` check, which causes unnecessary string creation overhead.
   Please add the `isDebugEnabled()` guard, or use parameterized `logger.debug()`, so strings won't be created/formatted at info/error logging level. 
   
   Same for other `logDebug` in this PR.

##########
File path: helix-core/src/main/java/org/apache/helix/controller/stages/PerReplicaThrottleStage.java
##########
@@ -220,6 +271,224 @@ private MessageOutput throttlePerReplicaMessages(IdealState idealState,
     return output;
   }
 
+  private void propagateCountsTopDown(StateModelDefinition stateModelDef,
+      Map<String, Integer> expectedStateCountMap) {
+    // attribute state in higher priority to lower priority
+    List<String> stateList = stateModelDef.getStatesPriorityList();
+    if (stateList.size() <= 0) {
+      return;
+    }
+    int index = 0;
+    String prevState = stateList.get(index);
+    if (!expectedStateCountMap.containsKey(prevState)) {
+      expectedStateCountMap.put(prevState, 0);
+    }
+    while (true) {
+      if (index == stateList.size() - 1) {
+        break;
+      }
+      index++;
+      String curState = stateList.get(index);
+      String num = stateModelDef.getNumInstancesPerState(curState);
+      if ("-1".equals(num)) {
+        break;
+      }
+      Integer prevCnt = expectedStateCountMap.get(prevState);
+      expectedStateCountMap
+          .put(curState, prevCnt + expectedStateCountMap.getOrDefault(curState, 0));
+      prevState = curState;
+    }
+  }
+
+  private void getPartitionExpectedAndCurrentStateCountMap(Partition partition,
+      Map<String, List<String>> preferenceLists, IdealState idealState,
+      ResourceControllerDataProvider cache, Map<String, String> currentStateMap,
+      Map<String, Integer> expectedStateCountMapOut, Map<String, Integer> currentStateCountsOut) {
+    List<String> preferenceList = preferenceLists.get(partition.getPartitionName());
+    if (preferenceList == null) {
+      preferenceList = Collections.emptyList();
+    }
+
+    int replica =
+        idealState.getMinActiveReplicas() == -1 ? idealState.getReplicaCount(preferenceList.size())
+            : idealState.getMinActiveReplicas();
+    Set<String> activeList = new HashSet<>(preferenceList);
+    activeList.retainAll(cache.getEnabledLiveInstances());
+
+    // For each state, check that this partition currently has the required number of that state as
+    // required by StateModelDefinition.
+    String stateModelDefName = idealState.getStateModelDefRef();
+    StateModelDefinition stateModelDef = cache.getStateModelDef(stateModelDefName);
+    LinkedHashMap<String, Integer> expectedStateCountMap =
+        stateModelDef.getStateCountMap(activeList.size(), replica); // StateModelDefinition's counts
+
+    // Current counts without disabled partitions or disabled instances
+    Map<String, String> currentStateMapWithoutDisabled = new HashMap<>(currentStateMap);
+    currentStateMapWithoutDisabled.keySet().removeAll(cache
+        .getDisabledInstancesForPartition(idealState.getResourceName(),
+            partition.getPartitionName()));
+    Map<String, Integer> currentStateCounts =
+        StateModelDefinition.getStateCounts(currentStateMapWithoutDisabled);
+
+    expectedStateCountMapOut.putAll(expectedStateCountMap);
+    currentStateCountsOut.putAll(currentStateCounts);
+    propagateCountsTopDown(stateModelDef, expectedStateCountMapOut);
+    propagateCountsTopDown(stateModelDef, currentStateCountsOut);
+  }
+
+  void calculateExistingAndCurrentStateCount(Map<Partition, List<Message>> selectedResourceMessages,
+      CurrentStateOutput currentStateOutput, BestPossibleStateOutput bestPossibleStateOutput,
+      IdealState idealState, ResourceControllerDataProvider cache,
+      Map<Partition, Map<String, Integer>> expectedStateCountByPartition,
+      Map<Partition, Map<String, Integer>> currentStateCountsByPartition) {
+    String resourceName = idealState.getResourceName();
+    Map<String, List<String>> preferenceLists =
+        bestPossibleStateOutput.getPreferenceLists(resourceName);
+    for (Partition partition : selectedResourceMessages.keySet()) {
+      Map<String, String> currentStateMap =
+          currentStateOutput.getCurrentStateMap(resourceName, partition);
+
+      Map<String, Integer> expectedStateCounts = new HashMap<>();
+      Map<String, Integer> currentStateCounts = new HashMap<>();
+      getPartitionExpectedAndCurrentStateCountMap(partition, preferenceLists, idealState, cache,
+          currentStateMap, expectedStateCounts, currentStateCounts);
+
+      // save these two maps for later usage
+      expectedStateCountByPartition.put(partition, expectedStateCounts);
+      currentStateCountsByPartition.put(partition, currentStateCounts);
+    }
+  }
+
+  /*
+   * Classify the messages of each partition into recovery and load messages.
+   */
+  private void classifyMessages(String resourceName, IdealState idealState,
+      ResourceControllerDataProvider cache, Map<Partition, List<Message>> selectedResourceMessages,
+      List<Message> recoveryMessages, List<Message> loadMessages,
+      Map<Partition, Map<String, Integer>> expectedStateCountByPartition,
+      Map<Partition, Map<String, Integer>> currentStateCountsByPartition) {
+    LogUtil.logInfo(logger, _eventId,
+        String.format("Classify message for resource: %s", resourceName));
+
+    for (Partition partition : selectedResourceMessages.keySet()) {
+      Map<String, Integer> expectedStateCountMap = expectedStateCountByPartition.get(partition);
+      Map<String, Integer> currentStateCounts = currentStateCountsByPartition.get(partition);
+
+      List<Message> partitionMessages = selectedResourceMessages.get(partition);
+      if (partitionMessages == null) {
+        continue;
+      }
+
+      String stateModelDefName = idealState.getStateModelDefRef();
+      StateModelDefinition stateModelDef = cache.getStateModelDef(stateModelDefName);
+      // sort partitionMessages based on transition priority and then creation timestamp for transition message
+      // TODO: sort messages in same partition in next PR
+      Set<String> disabledInstances =
+          cache.getDisabledInstancesForPartition(resourceName, partition.getPartitionName());
+      for (Message msg : partitionMessages) {
+        if (!Message.MessageType.STATE_TRANSITION.name().equals(msg.getMsgType())) {
+          if (logger.isDebugEnabled()) {
+            LogUtil.logDebug(logger, _eventId, String
+                .format("Message: %s not subject to throttle in resource: %s with type %s", msg,
+                    resourceName, msg.getMsgType()));
+          }
+          continue;
+        }
+
+        boolean isUpward = !isDownwardTransition(idealState, cache, msg);
+
+        // for disabled disabled instance, the downward transition is not subjected to load throttling
+        // we will let them pass through ASAP.
+        String instance = msg.getTgtName();
+        if (disabledInstances.contains(instance)) {
+          if (!isUpward) {
+            if (logger.isDebugEnabled()) {
+              LogUtil.logDebug(logger, _eventId, String.format(
+                  "Message: %s not subject to throttle in resource: %s to disabled instancce %s",
+                  msg, resourceName, instance));
+            }
+            continue;
+          }
+        }
+
+        String toState = msg.getToState();
+        if (toState.equals(HelixDefinedState.DROPPED.name()) || toState
+            .equals(HelixDefinedState.ERROR.name())) {
+          if (logger.isDebugEnabled()) {
+            LogUtil.logDebug(logger, _eventId, String
+                .format("Message: %s not subject to throttle in resource: %s with toState %s", msg,
+                    resourceName, toState));
+          }
+          continue;
+        }
+
+        Integer expectedCount = expectedStateCountMap.get(toState);
+        Integer currentCount = currentStateCounts.get(toState);
+        expectedCount = expectedCount == null ? 0 : expectedCount;
+        currentCount = currentCount == null ? 0 : currentCount;
+
+        if (isUpward && (currentCount < expectedCount)) {
+          recoveryMessages.add(msg);
+          currentStateCounts.put(toState, currentCount + 1);
+        } else {
+          loadMessages.add(msg);
+        }
+      }
+    }
+  }
+
+  protected void applyThrottling(String resourceName,
+      StateTransitionThrottleController throttleController, IdealState idealState,
+      ResourceControllerDataProvider cache, boolean onlyDownwardLoadBalance, List<Message> messages,
+      Set<Message> throttledMessages, StateTransitionThrottleConfig.RebalanceType rebalanceType) {
+    boolean isRecovery =
+        rebalanceType == StateTransitionThrottleConfig.RebalanceType.RECOVERY_BALANCE;
+    if (isRecovery && onlyDownwardLoadBalance) {
+      logger.error("onlyDownwardLoadBalance can't be used together with recovery_rebalance");
+      return;
+    }
+
+    // TODO: add message sorting in next PR
+    logger.trace("throttleControllerstate->{} before load", throttleController);
+    for (Message msg : messages) {
+      if (onlyDownwardLoadBalance) {
+        if (!isDownwardTransition(idealState, cache, msg)) {
+          throttledMessages.add(msg);
+          if (logger.isDebugEnabled()) {

Review comment:
       The debug code is kind of distracting. If these log statements are really necessary, I would suggest changing them to `logger.debug()`, which would reduce the lines of distracting code.

##########
File path: helix-core/src/main/java/org/apache/helix/controller/stages/PerReplicaThrottleStage.java
##########
@@ -220,6 +271,224 @@ private MessageOutput throttlePerReplicaMessages(IdealState idealState,
     return output;
   }
 
+  /**
+   * Accumulates state counts from higher-priority states into lower-priority states, in place.
+   *
+   * Walks the state model's priority list from the first (highest-priority) state downwards and
+   * adds the count of each state to the count of the state that follows it, so that every visited
+   * entry ends up holding its own count plus the counts of all states visited before it. The first
+   * state is seeded with 0 when it has no entry. The walk stops at the first state for which
+   * {@code getNumInstancesPerState} returns "-1"; that state and everything after it are left
+   * untouched.
+   *
+   * Despite the parameter name, this helper is applied to both the expected-count map and the
+   * current-count map.
+   *
+   * @param stateModelDef state model definition supplying the priority list and per-state
+   *          instance-count settings
+   * @param expectedStateCountMap state name to count; modified in place
+   */
+  private void propagateCountsTopDown(StateModelDefinition stateModelDef,
+      Map<String, Integer> expectedStateCountMap) {
+    // attribute state in higher priority to lower priority
+    List<String> stateList = stateModelDef.getStatesPriorityList();
+    // The priority list may not be set on the state model; there is nothing to propagate then.
+    if (stateList == null || stateList.isEmpty()) {
+      return;
+    }
+    String prevState = stateList.get(0);
+    expectedStateCountMap.putIfAbsent(prevState, 0);
+    for (int index = 1; index < stateList.size(); index++) {
+      String curState = stateList.get(index);
+      if ("-1".equals(stateModelDef.getNumInstancesPerState(curState))) {
+        break;
+      }
+      // prevState always has an entry here: it was either seeded above or written by the
+      // previous iteration.
+      int prevCnt = expectedStateCountMap.get(prevState);
+      expectedStateCountMap.put(curState, prevCnt + expectedStateCountMap.getOrDefault(curState, 0));
+      prevState = curState;
+    }
+  }
+
+  /**
+   * Computes the expected and the current per-state replica counts of a single partition and
+   * writes them into the two caller-supplied output maps.
+   *
+   * Expected counts come from the state model definition, based on how many of the partition's
+   * preferred instances are enabled and live and on the replica number. Current counts are taken
+   * from the current state map after dropping the instances that are disabled for this partition.
+   * Both output maps are finally run through {@code propagateCountsTopDown}.
+   *
+   * @param partition partition to compute the counts for
+   * @param preferenceLists preference list per partition name; a missing entry is treated as an
+   *          empty list
+   * @param idealState ideal state of the resource the partition belongs to
+   * @param cache data provider for live instances, state model definitions and disabled instances
+   * @param currentStateMap current state per instance; not modified, a filtered copy is used
+   * @param expectedStateCountMapOut output map receiving the expected counts
+   * @param currentStateCountsOut output map receiving the current counts
+   */
+  private void getPartitionExpectedAndCurrentStateCountMap(Partition partition,
+      Map<String, List<String>> preferenceLists, IdealState idealState,
+      ResourceControllerDataProvider cache, Map<String, String> currentStateMap,
+      Map<String, Integer> expectedStateCountMapOut, Map<String, Integer> currentStateCountsOut) {
+    String partitionName = partition.getPartitionName();
+    List<String> preferred = preferenceLists.get(partitionName);
+    if (preferred == null) {
+      preferred = Collections.emptyList();
+    }
+
+    // Use the ideal state's replica count only when getMinActiveReplicas() returns -1.
+    int replica = idealState.getMinActiveReplicas();
+    if (replica == -1) {
+      replica = idealState.getReplicaCount(preferred.size());
+    }
+
+    // Preferred instances that are also enabled and live.
+    Set<String> activeInstances = new HashSet<>(preferred);
+    activeInstances.retainAll(cache.getEnabledLiveInstances());
+
+    // Expected counts as dictated by the StateModelDefinition.
+    StateModelDefinition stateModelDef = cache.getStateModelDef(idealState.getStateModelDefRef());
+    Map<String, Integer> expectedCounts =
+        stateModelDef.getStateCountMap(activeInstances.size(), replica);
+
+    // Current counts, ignoring instances that are disabled for this partition.
+    Map<String, String> enabledCurrentStates = new HashMap<>(currentStateMap);
+    Set<String> disabledForPartition =
+        cache.getDisabledInstancesForPartition(idealState.getResourceName(), partitionName);
+    enabledCurrentStates.keySet().removeAll(disabledForPartition);
+    Map<String, Integer> currentCounts = StateModelDefinition.getStateCounts(enabledCurrentStates);
+
+    expectedStateCountMapOut.putAll(expectedCounts);
+    currentStateCountsOut.putAll(currentCounts);
+    propagateCountsTopDown(stateModelDef, expectedStateCountMapOut);
+    propagateCountsTopDown(stateModelDef, currentStateCountsOut);
+  }
+
+  void calculateExistingAndCurrentStateCount(Map<Partition, List<Message>> selectedResourceMessages,
+      CurrentStateOutput currentStateOutput, BestPossibleStateOutput bestPossibleStateOutput,
+      IdealState idealState, ResourceControllerDataProvider cache,
+      Map<Partition, Map<String, Integer>> expectedStateCountByPartition,
+      Map<Partition, Map<String, Integer>> currentStateCountsByPartition) {
+    String resourceName = idealState.getResourceName();
+    Map<String, List<String>> preferenceLists =
+        bestPossibleStateOutput.getPreferenceLists(resourceName);
+    for (Partition partition : selectedResourceMessages.keySet()) {
+      Map<String, String> currentStateMap =
+          currentStateOutput.getCurrentStateMap(resourceName, partition);
+
+      Map<String, Integer> expectedStateCounts = new HashMap<>();
+      Map<String, Integer> currentStateCounts = new HashMap<>();
+      getPartitionExpectedAndCurrentStateCountMap(partition, preferenceLists, idealState, cache,
+          currentStateMap, expectedStateCounts, currentStateCounts);
+
+      // save these two maps for later usage
+      expectedStateCountByPartition.put(partition, expectedStateCounts);
+      currentStateCountsByPartition.put(partition, currentStateCounts);
+    }
+  }
+
+  /*
+   * Classify the messages of each partition into recovery and load messages.
+   */
+  private void classifyMessages(String resourceName, IdealState idealState,
+      ResourceControllerDataProvider cache, Map<Partition, List<Message>> selectedResourceMessages,
+      List<Message> recoveryMessages, List<Message> loadMessages,
+      Map<Partition, Map<String, Integer>> expectedStateCountByPartition,
+      Map<Partition, Map<String, Integer>> currentStateCountsByPartition) {
+    LogUtil.logInfo(logger, _eventId,
+        String.format("Classify message for resource: %s", resourceName));
+
+    for (Partition partition : selectedResourceMessages.keySet()) {
+      Map<String, Integer> expectedStateCountMap = expectedStateCountByPartition.get(partition);
+      Map<String, Integer> currentStateCounts = currentStateCountsByPartition.get(partition);
+
+      List<Message> partitionMessages = selectedResourceMessages.get(partition);
+      if (partitionMessages == null) {
+        continue;
+      }
+
+      String stateModelDefName = idealState.getStateModelDefRef();
+      StateModelDefinition stateModelDef = cache.getStateModelDef(stateModelDefName);
+      // sort partitionMessages based on transition priority and then creation timestamp for transition message
+      // TODO: sort messages in same partition in next PR
+      Set<String> disabledInstances =
+          cache.getDisabledInstancesForPartition(resourceName, partition.getPartitionName());
+      for (Message msg : partitionMessages) {
+        if (!Message.MessageType.STATE_TRANSITION.name().equals(msg.getMsgType())) {
+          if (logger.isDebugEnabled()) {
+            LogUtil.logDebug(logger, _eventId, String
+                .format("Message: %s not subject to throttle in resource: %s with type %s", msg,
+                    resourceName, msg.getMsgType()));
+          }
+          continue;
+        }
+
+        boolean isUpward = !isDownwardTransition(idealState, cache, msg);
+
+        // for disabled disabled instance, the downward transition is not subjected to load throttling
+        // we will let them pass through ASAP.
+        String instance = msg.getTgtName();
+        if (disabledInstances.contains(instance)) {
+          if (!isUpward) {
+            if (logger.isDebugEnabled()) {
+              LogUtil.logDebug(logger, _eventId, String.format(
+                  "Message: %s not subject to throttle in resource: %s to disabled instancce %s",
+                  msg, resourceName, instance));
+            }
+            continue;
+          }
+        }
+
+        String toState = msg.getToState();
+        if (toState.equals(HelixDefinedState.DROPPED.name()) || toState
+            .equals(HelixDefinedState.ERROR.name())) {
+          if (logger.isDebugEnabled()) {
+            LogUtil.logDebug(logger, _eventId, String
+                .format("Message: %s not subject to throttle in resource: %s with toState %s", msg,
+                    resourceName, toState));
+          }
+          continue;
+        }
+
+        Integer expectedCount = expectedStateCountMap.get(toState);
+        Integer currentCount = currentStateCounts.get(toState);
+        expectedCount = expectedCount == null ? 0 : expectedCount;
+        currentCount = currentCount == null ? 0 : currentCount;

Review comment:
       `Map.getOrDefault()` would simplify these null checks, e.g. `expectedStateCountMap.getOrDefault(toState, 0)`.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@helix.apache.org
For additional commands, e-mail: reviews-help@helix.apache.org