You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@helix.apache.org by jx...@apache.org on 2021/06/01 18:44:45 UTC

[helix] 04/07: Applying per replica logic for entire stage #1724

This is an automated email from the ASF dual-hosted git repository.

jxue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/helix.git

commit 3a737ef7c94233d21dacce2bb211514abc012c46
Author: Junkai Xue <jx...@linkedin.com>
AuthorDate: Sun May 16 14:41:34 2021 -0700

    Applying per replica logic for entire stage #1724
    
    This commit contains:
    
    1. Per-resource looping with dynamic rebalance type computation and partition-level ordering.
    2. Make the entire stage work against the per-replica throttling logic.
    3. Move the intermediate state computation logic to a centralized place.
---
 .../helix/controller/common/PartitionStateMap.java |   9 +-
 .../stages/IntermediateStateCalcStage.java         | 244 +++++++++++----------
 .../helix/controller/stages/MessageOutput.java     |   4 +
 3 files changed, 137 insertions(+), 120 deletions(-)

diff --git a/helix-core/src/main/java/org/apache/helix/controller/common/PartitionStateMap.java b/helix-core/src/main/java/org/apache/helix/controller/common/PartitionStateMap.java
index e3c899a..7bfa6dc 100644
--- a/helix-core/src/main/java/org/apache/helix/controller/common/PartitionStateMap.java
+++ b/helix-core/src/main/java/org/apache/helix/controller/common/PartitionStateMap.java
@@ -24,6 +24,7 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.Set;
 
+import java.util.stream.Collectors;
 import org.apache.helix.model.Partition;
 
 /**
@@ -39,10 +40,12 @@ public class PartitionStateMap {
     _stateMap = new HashMap<>();
   }
 
-  public PartitionStateMap(String resourceName,
-      Map<Partition, Map<String, String>> partitionStateMap) {
+  // Takes a deep copy of partitionStateMap (the outer map and each per-partition map) instead of aliasing it.
+  public PartitionStateMap(String resourceName, Map<Partition, Map<String, String>> partitionStateMap) {
     _resourceName = resourceName;
-    _stateMap = partitionStateMap;
+    _stateMap = partitionStateMap.entrySet()
+        .stream()
+        .collect(Collectors.toMap(e -> e.getKey(), e -> new HashMap<>(e.getValue())));
   }
 
   public Set<Partition> partitionSet() {
diff --git a/helix-core/src/main/java/org/apache/helix/controller/stages/IntermediateStateCalcStage.java b/helix-core/src/main/java/org/apache/helix/controller/stages/IntermediateStateCalcStage.java
index 6a5d2f1..b91dca6 100644
--- a/helix-core/src/main/java/org/apache/helix/controller/stages/IntermediateStateCalcStage.java
+++ b/helix-core/src/main/java/org/apache/helix/controller/stages/IntermediateStateCalcStage.java
@@ -73,16 +73,17 @@ public class IntermediateStateCalcStage extends AbstractBaseStage {
         event.getAttribute(AttributeName.RESOURCES_TO_REBALANCE.name());
     ResourceControllerDataProvider cache =
         event.getAttribute(AttributeName.ControllerDataProvider.name());
+    MessageOutput messageOutput =event.getAttribute(AttributeName.MESSAGES_SELECTED.name());
 
     if (currentStateOutput == null || bestPossibleStateOutput == null || resourceToRebalance == null
-        || cache == null) {
+        || cache == null || messageOutput == null) {
       throw new StageException(String.format("Missing attributes in event: %s. "
-          + "Requires CURRENT_STATE (%s) |BEST_POSSIBLE_STATE (%s) |RESOURCES (%s) |DataCache (%s)",
-          event, currentStateOutput, bestPossibleStateOutput, resourceToRebalance, cache));
+              + "Requires CURRENT_STATE (%s) |BEST_POSSIBLE_STATE (%s) |RESOURCES (%s) |MESSAGE_SELECT (%s) |DataCache (%s)",
+          event, currentStateOutput, bestPossibleStateOutput, resourceToRebalance, messageOutput, cache));
     }
 
     IntermediateStateOutput intermediateStateOutput =
-        compute(event, resourceToRebalance, currentStateOutput, bestPossibleStateOutput);
+        compute(event, resourceToRebalance, currentStateOutput, bestPossibleStateOutput, messageOutput);
     event.addAttribute(AttributeName.INTERMEDIATE_STATE.name(), intermediateStateOutput);
 
     // Make sure no instance has more replicas/partitions assigned than maxPartitionPerInstance. If
@@ -106,7 +107,7 @@ public class IntermediateStateCalcStage extends AbstractBaseStage {
    * @return
    */
   private IntermediateStateOutput compute(ClusterEvent event, Map<String, Resource> resourceMap,
-      CurrentStateOutput currentStateOutput, BestPossibleStateOutput bestPossibleStateOutput) {
+      CurrentStateOutput currentStateOutput, BestPossibleStateOutput bestPossibleStateOutput, MessageOutput messageOutput) {
     IntermediateStateOutput output = new IntermediateStateOutput();
     ResourceControllerDataProvider dataCache =
         event.getAttribute(AttributeName.ControllerDataProvider.name());
@@ -173,14 +174,13 @@ public class IntermediateStateCalcStage extends AbstractBaseStage {
       }
 
       try {
-        output.setState(resourceName,
-            computeIntermediatePartitionState(dataCache, clusterStatusMonitor, idealState,
-                resourceMap.get(resourceName), currentStateOutput,
-                bestPossibleStateOutput.getPartitionStateMap(resourceName),
-                bestPossibleStateOutput.getPreferenceLists(resourceName), throttleController));
+        output.setState(resourceName, computeIntermediatePartitionState(dataCache, clusterStatusMonitor, idealState,
+            resourceMap.get(resourceName), currentStateOutput,
+            bestPossibleStateOutput.getPartitionStateMap(resourceName),
+            bestPossibleStateOutput.getPreferenceLists(resourceName), throttleController,
+            messageOutput.getResourceMessageMap(resourceName)));
       } catch (HelixException ex) {
-        LogUtil.logInfo(logger, _eventId,
-            "Failed to calculate intermediate partition states for resource " + resourceName, ex);
+        LogUtil.logInfo(logger, _eventId, "Failed to calculate intermediate partition states for resource " + resourceName, ex);
         failedResources.add(resourceName);
       }
     }
@@ -294,99 +294,41 @@ public class IntermediateStateCalcStage extends AbstractBaseStage {
       ClusterStatusMonitor clusterStatusMonitor, IdealState idealState, Resource resource,
       CurrentStateOutput currentStateOutput, PartitionStateMap bestPossiblePartitionStateMap,
       Map<String, List<String>> preferenceLists,
-      StateTransitionThrottleController throttleController) {
+      StateTransitionThrottleController throttleController, Map<Partition, List<Message>> resourceMessageMap) {
     String resourceName = resource.getResourceName();
     LogUtil.logDebug(logger, _eventId, String.format("Processing resource: %s", resourceName));
 
-    // Throttling is applied only on FULL-AUTO mode
-    if (!throttleController.isThrottleEnabled()
-        || !IdealState.RebalanceMode.FULL_AUTO.equals(idealState.getRebalanceMode())) {
+    // Throttling applies only in FULL_AUTO mode; if the resource message map is empty, no throttling is needed.
+    if (!throttleController.isThrottleEnabled() || !IdealState.RebalanceMode.FULL_AUTO.equals(
+        idealState.getRebalanceMode()) || resourceMessageMap.isEmpty()) {
       return bestPossiblePartitionStateMap;
     }
 
     String stateModelDefName = idealState.getStateModelDefRef();
     StateModelDefinition stateModelDef = cache.getStateModelDef(stateModelDefName);
-    PartitionStateMap intermediatePartitionStateMap = new PartitionStateMap(resourceName);
+    // This requires a deep copy of the current state map because some of the states will be overwritten by
+    // applying messages to it.
 
-    Set<Partition> partitionsNeedRecovery = new HashSet<>();
-    Set<Partition> partitionsNeedLoadBalance = new HashSet<>();
     Set<Partition> partitionsWithErrorStateReplica = new HashSet<>();
-    for (Partition partition : resource.getPartitions()) {
-      Map<String, String> currentStateMap =
-          currentStateOutput.getCurrentStateMap(resourceName, partition);
-      Map<String, String> bestPossibleMap =
-          bestPossiblePartitionStateMap.getPartitionMap(partition);
-      List<String> preferenceList = preferenceLists.get(partition.getPartitionName());
-
-      RebalanceType rebalanceType = getRebalanceType(cache, bestPossibleMap, preferenceList,
-          stateModelDef, currentStateMap, idealState, partition.getPartitionName());
-
-      // TODO: refine getRebalanceType to return more accurate rebalance types. So the following
-      // logic doesn't need to check for more details.
-      boolean isRebalanceNeeded = false;
-
-      // Check whether partition has any ERROR state replicas
-      if (currentStateMap.values().contains(HelixDefinedState.ERROR.name())) {
-        partitionsWithErrorStateReplica.add(partition);
-      }
-
-      // Number of states required by StateModelDefinition are not satisfied, need recovery
-      if (rebalanceType.equals(RebalanceType.RECOVERY_BALANCE)) {
-        // Check if recovery is needed for this partition
-        if (!currentStateMap.equals(bestPossibleMap)) {
-          partitionsNeedRecovery.add(partition);
-          isRebalanceNeeded = true;
-        }
-      } else if (rebalanceType.equals(RebalanceType.LOAD_BALANCE)) {
-        // Number of states required by StateModelDefinition are satisfied, but to achieve
-        // BestPossibleState, need load balance
-        partitionsNeedLoadBalance.add(partition);
-        isRebalanceNeeded = true;
-      }
-
-      // Currently at BestPossibleState, no further action necessary
-      if (!isRebalanceNeeded) {
-        Map<String, String> intermediateMap = new HashMap<>(bestPossibleMap);
-        intermediatePartitionStateMap.setState(partition, intermediateMap);
-      }
-    }
-
-    if (!partitionsNeedRecovery.isEmpty()) {
-      LogUtil.logInfo(logger, _eventId, String.format(
-          "Recovery balance needed for %s partitions: %s", resourceName, partitionsNeedRecovery));
-    }
-    if (!partitionsNeedLoadBalance.isEmpty()) {
-      LogUtil.logInfo(logger, _eventId, String.format("Load balance needed for %s partitions: %s",
-          resourceName, partitionsNeedLoadBalance));
-    }
-    if (!partitionsWithErrorStateReplica.isEmpty()) {
-      LogUtil.logInfo(logger, _eventId,
-          String.format("Partition currently has an ERROR replica in %s partitions: %s",
-              resourceName, partitionsWithErrorStateReplica));
-    }
-
-    chargePendingTransition(resource, currentStateOutput, throttleController, cache, preferenceLists, stateModelDef, intermediatePartitionStateMap);
-
-    // Perform recovery balance
-    Set<Partition> recoveryThrottledPartitions =
-        recoveryRebalance(resource, bestPossiblePartitionStateMap, throttleController,
-            intermediatePartitionStateMap, partitionsNeedRecovery, currentStateOutput,
-            cache.getStateModelDef(resource.getStateModelDefRef()).getTopState(), cache);
-
-    // Perform load balance upon checking conditions below
-    Set<Partition> loadbalanceThrottledPartitions;
+    Set<String> messagesForRecovery = new HashSet<>();
+    Set<String> messagesForLoad = new HashSet<>();
+    Set<String> messagesThrottledForRecovery = new HashSet<>();
+    Set<String> messagesThrottledForLoad = new HashSet<>();
     ClusterConfig clusterConfig = cache.getClusterConfig();
 
     // If the threshold (ErrorOrRecovery) is set, then use it, if not, then check if the old
     // threshold (Error) is set. If the old threshold is set, use it. If not, use the default value
     // for the new one. This is for backward-compatibility
     int threshold = 1; // Default threshold for ErrorOrRecoveryPartitionThresholdForLoadBalance
-    int partitionCount = partitionsWithErrorStateReplica.size();
+    // Keep the error count at the partition level. This logic only applies to downward state transition determination.
+    int numPartitionsWithErrorReplica = (int) currentStateOutput.getCurrentStateMap(resourceName)
+        .values()
+        .stream()
+        .filter(i -> i.values().contains(HelixDefinedState.ERROR.name()))
+        .count();
     if (clusterConfig.getErrorOrRecoveryPartitionThresholdForLoadBalance() != -1) {
       // ErrorOrRecovery is set
       threshold = clusterConfig.getErrorOrRecoveryPartitionThresholdForLoadBalance();
-      partitionCount += partitionsNeedRecovery.size(); // Only add this count when the threshold is
-      // set
     } else {
       if (clusterConfig.getErrorPartitionThresholdForLoadBalance() != 0) {
         // 0 is the default value so the old threshold has been set
@@ -396,23 +338,79 @@ public class IntermediateStateCalcStage extends AbstractBaseStage {
 
     // Perform regular load balance only if the number of partitions in recovery and in error is
     // less than the threshold. Otherwise, only allow downward-transition load balance
-    boolean onlyDownwardLoadBalance = partitionCount > threshold;
+    boolean onlyDownwardLoadBalance = numPartitionsWithErrorReplica > threshold;
+
+    chargePendingTransition(resource, currentStateOutput, throttleController, cache, preferenceLists, stateModelDef);
 
-    loadbalanceThrottledPartitions = loadRebalance(resource, currentStateOutput,
-        bestPossiblePartitionStateMap, throttleController, intermediatePartitionStateMap,
-        partitionsNeedLoadBalance, currentStateOutput.getCurrentStateMap(resourceName),
-        onlyDownwardLoadBalance, stateModelDef, cache);
+    // Sort partitions so that urgent partitions get to take the quota first.
+    List<Partition> partitions = new ArrayList<>(resource.getPartitions());
+    Collections.sort(partitions, new PartitionPriorityComparator(bestPossiblePartitionStateMap.getStateMap(),
+        currentStateOutput.getCurrentStateMap(resourceName), stateModelDef.getTopState()));
+    for (Partition partition : partitions) {
+      List<Message> messagesToThrottle = new ArrayList<>(resourceMessageMap.get(partition));
+      if (messagesToThrottle == null || messagesToThrottle.isEmpty()) {
+        continue;
+      }
+
+      Map<String, String> derivedCurrentStateMap = currentStateOutput.getCurrentStateMap(resourceName, partition)
+          .entrySet()
+          .stream()
+          .collect(Collectors.toMap(entry -> entry.getKey(), entry -> entry.getValue()));
+      List<String> preferenceList = preferenceLists.get(partition.getPartitionName());
+      Map<String, Integer> requiredState = getRequiredStates(resourceName, cache, preferenceList);
+      Collections.sort(messagesToThrottle,
+          new MessagePriorityComparator(preferenceList, stateModelDef.getStatePriorityMap()));
+      for (Message message : messagesToThrottle) {
+        RebalanceType rebalanceType = getRebalanceTypePerMessage(requiredState, message, derivedCurrentStateMap);
+
+        // Number of states required by StateModelDefinition are not satisfied, need recovery
+        if (rebalanceType.equals(RebalanceType.RECOVERY_BALANCE)) {
+          messagesForRecovery.add(message.getId());
+          recoveryRebalance(resource, partition, throttleController, message, cache, messagesThrottledForRecovery,
+              resourceMessageMap);
+        } else if (rebalanceType.equals(RebalanceType.LOAD_BALANCE)) {
+          messagesForLoad.add(message.getId());
+          loadRebalance(resource, partition, throttleController, message, cache, onlyDownwardLoadBalance, stateModelDef,
+              messagesThrottledForLoad, resourceMessageMap);
+        }
+
+        // Apply the message to temporary current state map
+        if (!messagesThrottledForRecovery.contains(message.getId()) && !messagesThrottledForLoad.contains(
+            message.getId())) {
+          derivedCurrentStateMap.put(message.getTgtName(), message.getToState());
+        }
+      }
+    }
+    // TODO: We may need to optimize it to be async compute for intermediate state output.
+    PartitionStateMap intermediatePartitionStateMap =
+        new PartitionStateMap(resourceName, currentStateOutput.getCurrentStateMap(resourceName));
+    computeIntermediateMap(intermediatePartitionStateMap, currentStateOutput.getPendingMessageMap(resourceName),
+        resourceMessageMap);
+
+    if (!messagesForRecovery.isEmpty()) {
+      LogUtil.logInfo(logger, _eventId, String.format(
+          "Recovery balance needed for %s with messages: %s", resourceName, messagesForRecovery));
+    }
+    if (!messagesForLoad.isEmpty()) {
+      LogUtil.logInfo(logger, _eventId, String.format("Load balance needed for %s with messages: %s",
+          resourceName, messagesForLoad));
+    }
+    if (!partitionsWithErrorStateReplica.isEmpty()) {
+      LogUtil.logInfo(logger, _eventId,
+          String.format("Partition currently has an ERROR replica in %s partitions: %s",
+              resourceName, partitionsWithErrorStateReplica));
+    }
 
     if (clusterStatusMonitor != null) {
-      clusterStatusMonitor.updateRebalancerStats(resourceName, partitionsNeedRecovery.size(),
-          partitionsNeedLoadBalance.size(), recoveryThrottledPartitions.size(),
-          loadbalanceThrottledPartitions.size());
+      clusterStatusMonitor.updateRebalancerStats(resourceName, messagesForRecovery.size(),
+          messagesForLoad.size(), messagesThrottledForRecovery.size(),
+          messagesThrottledForLoad.size());
     }
 
     if (logger.isDebugEnabled()) {
       logPartitionMapState(resourceName, new HashSet<>(resource.getPartitions()),
-          partitionsNeedRecovery, recoveryThrottledPartitions, partitionsNeedLoadBalance,
-          loadbalanceThrottledPartitions, currentStateOutput, bestPossiblePartitionStateMap,
+          messagesForRecovery, messagesThrottledForRecovery, messagesForLoad,
+          messagesThrottledForLoad, currentStateOutput, bestPossiblePartitionStateMap,
           intermediatePartitionStateMap);
     }
 
@@ -449,8 +447,7 @@ public class IntermediateStateCalcStage extends AbstractBaseStage {
    */
   private void chargePendingTransition(Resource resource, CurrentStateOutput currentStateOutput,
       StateTransitionThrottleController throttleController, ResourceControllerDataProvider cache,
-      Map<String, List<String>> preferenceLists, StateModelDefinition stateModelDefinition,
-      PartitionStateMap intermediatePartitionStateMap) {
+      Map<String, List<String>> preferenceLists, StateModelDefinition stateModelDefinition) {
     String resourceName = resource.getResourceName();
     // check and charge pending transitions
     for (Partition partition : resource.getPartitions()) {
@@ -478,7 +475,6 @@ public class IntermediateStateCalcStage extends AbstractBaseStage {
           throttleController.chargeResource(rebalanceType, resourceName);
           throttleController.chargeCluster(rebalanceType);
         }
-        intermediatePartitionStateMap.setState(partition, message.getTgtName(), message.getToState());
       }
     }
   }
@@ -489,7 +485,6 @@ public class IntermediateStateCalcStage extends AbstractBaseStage {
    * @param resource                      the resource to throttle
    * @param throttleController            throttle controller object
    * @param messageToThrottle             the message to be throttled
-   * @param intermediatePartitionStateMap output result for this stage that intermediate state map
    * @param cache                         cache object for computational metadata from external storage
    * @param messagesThrottled             messages that have already been throttled
    * @param resourceMessageMap            the map for all messages from MessageSelectStage. Remove the message
@@ -497,10 +492,10 @@ public class IntermediateStateCalcStage extends AbstractBaseStage {
    */
   private void recoveryRebalance(Resource resource, Partition partition,
       StateTransitionThrottleController throttleController, Message messageToThrottle,
-      PartitionStateMap intermediatePartitionStateMap, ResourceControllerDataProvider cache,
-      Set<Message> messagesThrottled, Map<Partition, List<Message>> resourceMessageMap) {
+      ResourceControllerDataProvider cache, Set<String> messagesThrottled,
+      Map<Partition, List<Message>> resourceMessageMap) {
     throttleStateTransitionsForReplica(throttleController, resource.getResourceName(), partition, messageToThrottle,
-        messagesThrottled, intermediatePartitionStateMap, RebalanceType.RECOVERY_BALANCE, cache, resourceMessageMap);
+        messagesThrottled, RebalanceType.RECOVERY_BALANCE, cache, resourceMessageMap);
   }
 
   /**
@@ -509,7 +504,6 @@ public class IntermediateStateCalcStage extends AbstractBaseStage {
    * @param resource                      the resource to throttle
    * @param throttleController            throttle controller object
    * @param messageToThrottle             the message to be throttle
-   * @param intermediatePartitionStateMap output result for this stage that intermediate state map
    * @param cache                         cache object for computational metadata from external storage
    * @param onlyDownwardLoadBalance       does allow only downward load balance
    * @param stateModelDefinition          state model definition of this resource
@@ -519,16 +513,15 @@ public class IntermediateStateCalcStage extends AbstractBaseStage {
    */
   private void loadRebalance(Resource resource, Partition partition,
       StateTransitionThrottleController throttleController, Message messageToThrottle,
-      PartitionStateMap intermediatePartitionStateMap, ResourceControllerDataProvider cache,
-      boolean onlyDownwardLoadBalance, StateModelDefinition stateModelDefinition, Set<Message> messagesThrottled,
+       ResourceControllerDataProvider cache,
+      boolean onlyDownwardLoadBalance, StateModelDefinition stateModelDefinition, Set<String> messagesThrottled,
       Map<Partition, List<Message>> resourceMessageMap) {
     if (onlyDownwardLoadBalance && isLoadBalanceDownwardStateTransition(messageToThrottle, stateModelDefinition)) {
       // Remove the message already allowed for downward state transitions.
-      intermediatePartitionStateMap.setState(partition, messageToThrottle.getTgtName(), messageToThrottle.getToState());
       return;
     }
     throttleStateTransitionsForReplica(throttleController, resource.getResourceName(), partition, messageToThrottle,
-        messagesThrottled, intermediatePartitionStateMap, RebalanceType.LOAD_BALANCE, cache, resourceMessageMap);
+        messagesThrottled, RebalanceType.LOAD_BALANCE, cache, resourceMessageMap);
   }
 
   /**
@@ -540,17 +533,15 @@ public class IntermediateStateCalcStage extends AbstractBaseStage {
    * @param messageToThrottle                 the message to be throttled
    * @param messagesThrottled                 the cumulative set of messages that have been throttled already. These
    *                                          messages represent the replicas of this partition that have been throttled.
-   * @param intermediatePartitionStateMap     the cumulative partition-state mapping as a result of the throttling step
-   *                                          of IntermediateStateCalcStage
    * @param rebalanceType                     the rebalance type to charge quota
    * @param cache                             cached cluster metadata required by the throttle controller
    * @param resourceMessageMap                the map for all messages from MessageSelectStage. Remove the message
    *                                          if it has been throttled.
    */
   private void throttleStateTransitionsForReplica(StateTransitionThrottleController throttleController,
-      String resourceName, Partition partition, Message messageToThrottle, Set<Message> messagesThrottled,
-      PartitionStateMap intermediatePartitionStateMap, RebalanceType rebalanceType,
-      ResourceControllerDataProvider cache, Map<Partition, List<Message>> resourceMessageMap) {
+      String resourceName, Partition partition, Message messageToThrottle, Set<String> messagesThrottled,
+      RebalanceType rebalanceType, ResourceControllerDataProvider cache,
+      Map<Partition, List<Message>> resourceMessageMap) {
     boolean hasReachedThrottlingLimit = false;
     if (throttleController.shouldThrottleForResource(rebalanceType, resourceName)) {
       hasReachedThrottlingLimit = true;
@@ -578,13 +569,12 @@ public class IntermediateStateCalcStage extends AbstractBaseStage {
     if (!hasReachedThrottlingLimit) {
       throttleController.chargeCluster(rebalanceType);
       throttleController.chargeResource(rebalanceType, resourceName);
-      intermediatePartitionStateMap.setState(partition, messageToThrottle.getTgtName(), messageToThrottle.getToState());
     } else {
       // Intermediate Map is based on current state
       // Remove the message from MessageSelection result if it has been throttled since the message will be dispatched
       // by next stage if it is not removed.
       resourceMessageMap.get(partition).remove(messageToThrottle);
-      messagesThrottled.add(messageToThrottle);
+      messagesThrottled.add(messageToThrottle.getId());
     }
   }
 
@@ -716,8 +706,8 @@ public class IntermediateStateCalcStage extends AbstractBaseStage {
    * @param intermediateStateMap
    */
   private void logPartitionMapState(String resource, Set<Partition> allPartitions,
-      Set<Partition> recoveryPartitions, Set<Partition> recoveryThrottledPartitions,
-      Set<Partition> loadbalancePartitions, Set<Partition> loadbalanceThrottledPartitions,
+      Set<String> recoveryPartitions, Set<String> recoveryThrottledPartitions,
+      Set<String> loadbalancePartitions, Set<String> loadbalanceThrottledPartitions,
       CurrentStateOutput currentStateOutput, PartitionStateMap bestPossibleStateMap,
       PartitionStateMap intermediateStateMap) {
 
@@ -894,6 +884,26 @@ public class IntermediateStateCalcStage extends AbstractBaseStage {
   }
 
   /**
+   * Generate the IntermediateStateMap from the pending messages plus the generated messages.
+   */
+  private void computeIntermediateMap(PartitionStateMap intermediateStateMap,
+      Map<Partition, Map<String, Message>> pendingMessageMap, Map<Partition, List<Message>> resourceMessageMap) {
+    for (Map.Entry<Partition, Map<String, Message>> entry : pendingMessageMap.entrySet()) {
+      entry.getValue()
+          .entrySet()
+          .stream()
+          .forEach(
+              e -> intermediateStateMap.setState(entry.getKey(), e.getValue().getTgtName(), e.getValue().getToState()));
+    }
+
+    for (Map.Entry<Partition, List<Message>> entry : resourceMessageMap.entrySet()) {
+      entry.getValue()
+          .stream()
+          .forEach(e -> intermediateStateMap.setState(entry.getKey(), e.getTgtName(), e.getToState()));
+    }
+  }
+
+  /**
    * Handle a partition with a pending message so that the partition will not be double-charged or double-assigned during recovery and load balance.
    * @param partition
    * @param partitionsNeedRecovery
diff --git a/helix-core/src/main/java/org/apache/helix/controller/stages/MessageOutput.java b/helix-core/src/main/java/org/apache/helix/controller/stages/MessageOutput.java
index dd545c3..ad7e6c8 100644
--- a/helix-core/src/main/java/org/apache/helix/controller/stages/MessageOutput.java
+++ b/helix-core/src/main/java/org/apache/helix/controller/stages/MessageOutput.java
@@ -64,6 +64,10 @@ public class MessageOutput {
     return Collections.emptyList();
   }
 
+  public Map<Partition, List<Message>> getResourceMessageMap(String resourceName) {
+    return _messagesMap.getOrDefault(resourceName, Collections.emptyMap());
+  }
+
   @Override
   public String toString() {
     return _messagesMap.toString();