You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@helix.apache.org by ji...@apache.org on 2019/11/13 23:50:10 UTC

[helix] branch wagedRebalancer updated: Improve the WAGED rebalancer performance. (#586)

This is an automated email from the ASF dual-hosted git repository.

jiajunwang pushed a commit to branch wagedRebalancer
in repository https://gitbox.apache.org/repos/asf/helix.git


The following commit(s) were added to refs/heads/wagedRebalancer by this push:
     new 1f6ed2d  Improve the WAGED rebalancer performance. (#586)
1f6ed2d is described below

commit 1f6ed2d76bf36210ade0356f4ff551acd393e371
Author: Jiajun Wang <18...@users.noreply.github.com>
AuthorDate: Wed Nov 13 15:49:59 2019 -0800

    Improve the WAGED rebalancer performance. (#586)
    
    This change improves the rebalance speed by 2x to 5x, depending on the host capacity.
    
    Parallelize the loop processing wherever possible to improve the performance. This does not change the logic.
    Avoid some duplicate logic in the loops: move the calculation outside the loop so that it is only done once.
---
 .../rebalancer/waged/WagedRebalancer.java          | 185 +++++++++++----------
 .../constraints/ConstraintBasedAlgorithm.java      |  26 +--
 .../rebalancer/waged/model/AssignableNode.java     |  17 +-
 .../rebalancer/waged/model/AssignableReplica.java  |   4 +-
 .../rebalancer/waged/model/ClusterModel.java       |   2 +-
 .../waged/model/ClusterModelProvider.java          |  27 ++-
 .../rebalancer/waged/model/OptimalAssignment.java  |  36 ++--
 .../stages/BestPossibleStateCalcStage.java         |   2 -
 8 files changed, 150 insertions(+), 149 deletions(-)

diff --git a/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/WagedRebalancer.java b/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/WagedRebalancer.java
index c472e77..bd28de0 100644
--- a/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/WagedRebalancer.java
+++ b/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/WagedRebalancer.java
@@ -84,7 +84,17 @@ public class WagedRebalancer {
   private final HelixManager _manager;
   private final MappingCalculator<ResourceControllerDataProvider> _mappingCalculator;
   private final AssignmentMetadataStore _assignmentMetadataStore;
+
   private final MetricCollector _metricCollector;
+  private final CountMetric _rebalanceFailureCount;
+  private final CountMetric _globalBaselineCalcCounter;
+  private final LatencyMetric _globalBaselineCalcLatency;
+  private final LatencyMetric _writeLatency;
+  private final CountMetric _partialRebalanceCounter;
+  private final LatencyMetric _partialRebalanceLatency;
+  private final LatencyMetric _stateReadLatency;
+  private final BaselineDivergenceGauge _baselineDivergenceGauge;
+
   private RebalanceAlgorithm _rebalanceAlgorithm;
   private Map<ClusterConfig.GlobalRebalancePreferenceKey, Integer> _preference =
       NOT_CONFIGURED_PREFERENCE;
@@ -155,10 +165,38 @@ public class WagedRebalancer {
     _rebalanceAlgorithm = algorithm;
     _mappingCalculator = mappingCalculator;
     _manager = manager;
+
     // If metricCollector is null, instantiate a version that does not register metrics in order to
     // allow rebalancer to proceed
     _metricCollector =
         metricCollector == null ? new WagedRebalancerMetricCollector() : metricCollector;
+    _rebalanceFailureCount = _metricCollector.getMetric(
+        WagedRebalancerMetricCollector.WagedRebalancerMetricNames.RebalanceFailureCounter.name(),
+        CountMetric.class);
+    _globalBaselineCalcCounter = _metricCollector.getMetric(
+        WagedRebalancerMetricCollector.WagedRebalancerMetricNames.GlobalBaselineCalcCounter.name(),
+        CountMetric.class);
+    _globalBaselineCalcLatency = _metricCollector.getMetric(
+        WagedRebalancerMetricCollector.WagedRebalancerMetricNames.GlobalBaselineCalcLatencyGauge
+            .name(),
+        LatencyMetric.class);
+    _partialRebalanceCounter = _metricCollector.getMetric(
+        WagedRebalancerMetricCollector.WagedRebalancerMetricNames.PartialRebalanceCounter.name(),
+        CountMetric.class);
+    _partialRebalanceLatency = _metricCollector.getMetric(
+        WagedRebalancerMetricCollector.WagedRebalancerMetricNames.PartialRebalanceLatencyGauge
+            .name(),
+        LatencyMetric.class);
+    _writeLatency = _metricCollector.getMetric(
+        WagedRebalancerMetricCollector.WagedRebalancerMetricNames.StateWriteLatencyGauge.name(),
+        LatencyMetric.class);
+    _stateReadLatency = _metricCollector.getMetric(
+        WagedRebalancerMetricCollector.WagedRebalancerMetricNames.StateReadLatencyGauge.name(),
+        LatencyMetric.class);
+    _baselineDivergenceGauge = _metricCollector.getMetric(
+        WagedRebalancerMetricCollector.WagedRebalancerMetricNames.BaselineDivergenceGauge.name(),
+        BaselineDivergenceGauge.class);
+
     _changeDetector = new ResourceChangeDetector(true);
   }
 
@@ -209,10 +247,7 @@ public class WagedRebalancer {
     } catch (HelixRebalanceException ex) {
       LOG.error("Failed to calculate the new assignments.", ex);
       // Record the failure in metrics.
-      CountMetric rebalanceFailureCount = _metricCollector.getMetric(
-          WagedRebalancerMetricCollector.WagedRebalancerMetricNames.RebalanceFailureCounter.name(),
-          CountMetric.class);
-      rebalanceFailureCount.increment(1L);
+      _rebalanceFailureCount.increment(1L);
 
       HelixRebalanceException.Type failureType = ex.getFailureType();
       if (failureType.equals(HelixRebalanceException.Type.INVALID_REBALANCER_STATUS) || failureType
@@ -236,22 +271,23 @@ public class WagedRebalancer {
     // Construct the new best possible states according to the current state and target assignment.
     // Note that the new ideal state might be an intermediate state between the current state and
     // the target assignment.
-    for (IdealState is : newIdealStates.values()) {
-      String resourceName = is.getResourceName();
+    newIdealStates.values().parallelStream().forEach(idealState -> {
+      String resourceName = idealState.getResourceName();
       // Adjust the states according to the current state.
-      ResourceAssignment finalAssignment = _mappingCalculator.computeBestPossiblePartitionState(
-          clusterData, is, resourceMap.get(resourceName), currentStateOutput);
+      ResourceAssignment finalAssignment = _mappingCalculator
+          .computeBestPossiblePartitionState(clusterData, idealState, resourceMap.get(resourceName),
+              currentStateOutput);
 
       // Clean up the state mapping fields. Use the final assignment that is calculated by the
       // mapping calculator to replace them.
-      is.getRecord().getMapFields().clear();
+      idealState.getRecord().getMapFields().clear();
       for (Partition partition : finalAssignment.getMappedPartitions()) {
         Map<String, String> newStateMap = finalAssignment.getReplicaMap(partition);
         // if the final states cannot be generated, override the best possible state with empty map.
-        is.setInstanceStateMap(partition.getPartitionName(),
+        idealState.setInstanceStateMap(partition.getPartitionName(),
             newStateMap == null ? Collections.emptyMap() : newStateMap);
       }
-    }
+    });
     LOG.info("Finish computing new ideal states for resources: {}",
         resourceMap.keySet().toString());
     return newIdealStates;
@@ -296,8 +332,8 @@ public class WagedRebalancer {
       Set<String> activeNodes, final CurrentStateOutput currentStateOutput)
       throws HelixRebalanceException {
     getChangeDetector().updateSnapshots(clusterData);
-    // Get all the changed items' information
-    Map<HelixConstants.ChangeType, Set<String>> clusterChanges =
+    // Get all the changed items' information. Filter for the items that have content changed.
+    final Map<HelixConstants.ChangeType, Set<String>> clusterChanges =
         getChangeDetector().getChangeTypes().stream()
             .collect(Collectors.toMap(changeType -> changeType, changeType -> {
               Set<String> itemKeys = new HashSet<>();
@@ -305,18 +341,12 @@ public class WagedRebalancer {
               itemKeys.addAll(getChangeDetector().getChangesByType(changeType));
               itemKeys.addAll(getChangeDetector().getRemovalsByType(changeType));
               return itemKeys;
-            }));
-    // Filter for the items that have content changed.
-    clusterChanges =
-        clusterChanges.entrySet().stream().filter(changeEntry -> !changeEntry.getValue().isEmpty())
+            })).entrySet().stream().filter(changeEntry -> !changeEntry.getValue().isEmpty())
             .collect(Collectors
                 .toMap(changeEntry -> changeEntry.getKey(), changeEntry -> changeEntry.getValue()));
 
     // Perform Global Baseline Calculation
-    if (clusterChanges.keySet().stream()
-        .anyMatch(GLOBAL_REBALANCE_REQUIRED_CHANGE_TYPES::contains)) {
-      refreshBaseline(clusterData, clusterChanges, resourceMap, currentStateOutput);
-    }
+    refreshBaseline(clusterData, clusterChanges, resourceMap, currentStateOutput);
 
     // Perform partial rebalance
     Map<String, ResourceAssignment> newAssignment =
@@ -362,47 +392,41 @@ public class WagedRebalancer {
   // TODO make the Baseline calculation async if complicated algorithm is used for the Baseline
   private void refreshBaseline(ResourceControllerDataProvider clusterData,
       Map<HelixConstants.ChangeType, Set<String>> clusterChanges, Map<String, Resource> resourceMap,
-      final CurrentStateOutput currentStateOutput) throws HelixRebalanceException {
-    LOG.info("Start calculating the new baseline.");
-    CountMetric globalBaselineCalcCounter = _metricCollector.getMetric(
-        WagedRebalancerMetricCollector.WagedRebalancerMetricNames.GlobalBaselineCalcCounter.name(),
-        CountMetric.class);
-    globalBaselineCalcCounter.increment(1L);
-
-    LatencyMetric globalBaselineCalcLatency = _metricCollector.getMetric(
-        WagedRebalancerMetricCollector.WagedRebalancerMetricNames.GlobalBaselineCalcLatencyGauge
-            .name(),
-        LatencyMetric.class);
-    globalBaselineCalcLatency.startMeasuringLatency();
-    // Read the baseline from metadata store
-    Map<String, ResourceAssignment> currentBaseline =
-        getBaselineAssignment(_assignmentMetadataStore, currentStateOutput, resourceMap.keySet());
-
-    // For baseline calculation
-    // 1. Ignore node status (disable/offline).
-    // 2. Use the baseline as the previous best possible assignment since there is no "baseline" for
-    // the baseline.
-    Map<String, ResourceAssignment> newBaseline = calculateAssignment(clusterData, clusterChanges,
-        resourceMap, clusterData.getAllInstances(), Collections.emptyMap(), currentBaseline);
-
-    // Write the new baseline to metadata store
-    if (_assignmentMetadataStore != null) {
-      try {
-        LatencyMetric writeLatency = _metricCollector.getMetric(
-            WagedRebalancerMetricCollector.WagedRebalancerMetricNames.StateWriteLatencyGauge.name(),
-            LatencyMetric.class);
-        writeLatency.startMeasuringLatency();
-        _assignmentMetadataStore.persistBaseline(newBaseline);
-        writeLatency.endMeasuringLatency();
-      } catch (Exception ex) {
-        throw new HelixRebalanceException("Failed to persist the new baseline assignment.",
-            HelixRebalanceException.Type.INVALID_REBALANCER_STATUS, ex);
+      final CurrentStateOutput currentStateOutput)
+      throws HelixRebalanceException {
+    if (clusterChanges.keySet().stream()
+        .anyMatch(GLOBAL_REBALANCE_REQUIRED_CHANGE_TYPES::contains)) {
+      LOG.info("Start calculating the new baseline.");
+      _globalBaselineCalcCounter.increment(1L);
+      _globalBaselineCalcLatency.startMeasuringLatency();
+
+      // For baseline calculation
+      // 1. Ignore node status (disable/offline).
+      // 2. Use the baseline as the previous best possible assignment since there is no "baseline" for
+      // the baseline.
+      // Read the baseline from metadata store
+      Map<String, ResourceAssignment> currentBaseline =
+          getBaselineAssignment(_assignmentMetadataStore, currentStateOutput, resourceMap.keySet());
+      Map<String, ResourceAssignment> newBaseline =
+          calculateAssignment(clusterData, clusterChanges, resourceMap,
+              clusterData.getAllInstances(), Collections.emptyMap(), currentBaseline);
+
+      // Write the new baseline to metadata store
+      if (_assignmentMetadataStore != null) {
+        try {
+          _writeLatency.startMeasuringLatency();
+          _assignmentMetadataStore.persistBaseline(newBaseline);
+          _writeLatency.endMeasuringLatency();
+        } catch (Exception ex) {
+          throw new HelixRebalanceException("Failed to persist the new baseline assignment.",
+              HelixRebalanceException.Type.INVALID_REBALANCER_STATUS, ex);
+        }
+      } else {
+        LOG.debug("Assignment Metadata Store is empty. Skip persist the baseline assignment.");
       }
-    } else {
-      LOG.debug("Assignment Metadata Store is empty. Skip persist the baseline assignment.");
+      _globalBaselineCalcLatency.endMeasuringLatency();
+      LOG.info("Finish calculating the new baseline.");
     }
-    globalBaselineCalcLatency.endMeasuringLatency();
-    LOG.info("Finish calculating the new baseline.");
   }
 
   private Map<String, ResourceAssignment> partialRebalance(
@@ -411,16 +435,8 @@ public class WagedRebalancer {
       Set<String> activeNodes, final CurrentStateOutput currentStateOutput)
       throws HelixRebalanceException {
     LOG.info("Start calculating the new best possible assignment.");
-    CountMetric partialRebalanceCounter = _metricCollector.getMetric(
-        WagedRebalancerMetricCollector.WagedRebalancerMetricNames.PartialRebalanceCounter.name(),
-        CountMetric.class);
-    partialRebalanceCounter.increment(1L);
-
-    LatencyMetric partialRebalanceLatency = _metricCollector.getMetric(
-        WagedRebalancerMetricCollector.WagedRebalancerMetricNames.PartialRebalanceLatencyGauge
-            .name(),
-        LatencyMetric.class);
-    partialRebalanceLatency.startMeasuringLatency();
+    _partialRebalanceCounter.increment(1L);
+    _partialRebalanceLatency.startMeasuringLatency();
     // TODO: Consider combining the metrics for both baseline/best possible?
     // Read the baseline from metadata store
     Map<String, ResourceAssignment> currentBaseline =
@@ -442,22 +458,17 @@ public class WagedRebalancer {
     for (Map.Entry<String, ResourceAssignment> entry : newAssignment.entrySet()) {
       newAssignmentCopy.put(entry.getKey(), new ResourceAssignment(entry.getValue().getRecord()));
     }
-    BaselineDivergenceGauge baselineDivergenceGauge = _metricCollector.getMetric(
-        WagedRebalancerMetricCollector.WagedRebalancerMetricNames.BaselineDivergenceGauge.name(),
-        BaselineDivergenceGauge.class);
-    baselineDivergenceGauge.asyncMeasureAndUpdateValue(clusterData.getAsyncTasksThreadPool(),
+
+    _baselineDivergenceGauge.asyncMeasureAndUpdateValue(clusterData.getAsyncTasksThreadPool(),
         currentBaseline, newAssignmentCopy);
 
     if (_assignmentMetadataStore != null) {
       try {
-        LatencyMetric writeLatency = _metricCollector.getMetric(
-            WagedRebalancerMetricCollector.WagedRebalancerMetricNames.StateWriteLatencyGauge.name(),
-            LatencyMetric.class);
-        writeLatency.startMeasuringLatency();
+        _writeLatency.startMeasuringLatency();
         // TODO Test to confirm if persisting the final assignment (with final partition states)
         // would be a better option.
         _assignmentMetadataStore.persistBestPossibleAssignment(newAssignment);
-        writeLatency.endMeasuringLatency();
+        _writeLatency.endMeasuringLatency();
       } catch (Exception ex) {
         throw new HelixRebalanceException("Failed to persist the new best possible assignment.",
             HelixRebalanceException.Type.INVALID_REBALANCER_STATUS, ex);
@@ -465,7 +476,7 @@ public class WagedRebalancer {
     } else {
       LOG.debug("Assignment Metadata Store is empty. Skip persist the baseline assignment.");
     }
-    partialRebalanceLatency.endMeasuringLatency();
+    _partialRebalanceLatency.endMeasuringLatency();
     LOG.info("Finish calculating the new best possible assignment.");
     return newAssignment;
   }
@@ -561,12 +572,9 @@ public class WagedRebalancer {
     Map<String, ResourceAssignment> currentBaseline = Collections.emptyMap();
     if (assignmentMetadataStore != null) {
       try {
-        LatencyMetric stateReadLatency = _metricCollector.getMetric(
-            WagedRebalancerMetricCollector.WagedRebalancerMetricNames.StateReadLatencyGauge.name(),
-            LatencyMetric.class);
-        stateReadLatency.startMeasuringLatency();
+        _stateReadLatency.startMeasuringLatency();
         currentBaseline = assignmentMetadataStore.getBaseline();
-        stateReadLatency.endMeasuringLatency();
+        _stateReadLatency.endMeasuringLatency();
       } catch (Exception ex) {
         throw new HelixRebalanceException(
             "Failed to get the current baseline assignment because of unexpected error.",
@@ -595,12 +603,9 @@ public class WagedRebalancer {
     Map<String, ResourceAssignment> currentBestAssignment = Collections.emptyMap();
     if (assignmentMetadataStore != null) {
       try {
-        LatencyMetric stateReadLatency = _metricCollector.getMetric(
-            WagedRebalancerMetricCollector.WagedRebalancerMetricNames.StateReadLatencyGauge.name(),
-            LatencyMetric.class);
-        stateReadLatency.startMeasuringLatency();
+        _stateReadLatency.startMeasuringLatency();
         currentBestAssignment = assignmentMetadataStore.getBestPossibleAssignment();
-        stateReadLatency.endMeasuringLatency();
+        _stateReadLatency.endMeasuringLatency();
       } catch (Exception ex) {
         throw new HelixRebalanceException(
             "Failed to get the current best possible assignment because of unexpected error.",
diff --git a/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/constraints/ConstraintBasedAlgorithm.java b/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/constraints/ConstraintBasedAlgorithm.java
index 65737f7..0956341 100644
--- a/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/constraints/ConstraintBasedAlgorithm.java
+++ b/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/constraints/ConstraintBasedAlgorithm.java
@@ -27,7 +27,7 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
 import java.util.Set;
-import java.util.function.Function;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Collectors;
 
 import com.google.common.collect.Maps;
@@ -90,8 +90,8 @@ class ConstraintBasedAlgorithm implements RebalanceAlgorithm {
   private Optional<AssignableNode> getNodeWithHighestPoints(AssignableReplica replica,
       List<AssignableNode> assignableNodes, ClusterContext clusterContext,
       OptimalAssignment optimalAssignment) {
-    Map<AssignableNode, List<HardConstraint>> hardConstraintFailures = new HashMap<>();
-    List<AssignableNode> candidateNodes = assignableNodes.stream().filter(candidateNode -> {
+    Map<AssignableNode, List<HardConstraint>> hardConstraintFailures = new ConcurrentHashMap<>();
+    List<AssignableNode> candidateNodes = assignableNodes.parallelStream().filter(candidateNode -> {
       boolean isValid = true;
       // need to record all the failure reasons and it gives us the ability to debug/fix the runtime
       // cluster environment
@@ -104,16 +104,17 @@ class ConstraintBasedAlgorithm implements RebalanceAlgorithm {
       }
       return isValid;
     }).collect(Collectors.toList());
+
     if (candidateNodes.isEmpty()) {
       optimalAssignment.recordAssignmentFailure(replica,
           Maps.transformValues(hardConstraintFailures, this::convertFailureReasons));
       return Optional.empty();
     }
 
-    Function<AssignableNode, Double> calculatePoints =
-        (candidateNode) -> getAssignmentNormalizedScore(candidateNode, replica, clusterContext);
-
-    return candidateNodes.stream().max(Comparator.comparing(calculatePoints));
+    return candidateNodes.parallelStream().map(node -> new HashMap.SimpleEntry<>(node,
+        getAssignmentNormalizedScore(node, replica, clusterContext)))
+        .max(Comparator.comparingDouble((scoreEntry) -> scoreEntry.getValue()))
+        .map(Map.Entry::getKey);
   }
 
   private double getAssignmentNormalizedScore(AssignableNode node, AssignableReplica replica,
@@ -146,6 +147,11 @@ class ConstraintBasedAlgorithm implements RebalanceAlgorithm {
     Map<String, ResourceAssignment> baselineAssignment =
         clusterModel.getContext().getBaselineAssignment();
 
+    Map<String, Integer> replicaHashCodeMap = orderedAssignableReplicas.parallelStream().collect(
+        Collectors.toMap(AssignableReplica::toString,
+            replica -> Objects.hash(replica.toString(), clusterModel.getAssignableNodes().keySet()),
+            (hash1, hash2) -> hash2));
+
     // 1. Sort according if the assignment exists in the best possible and/or baseline assignment
     // 2. Sort according to the state priority. Note that prioritizing the top state is required.
     // Or the greedy algorithm will unnecessarily shuffle the states between replicas.
@@ -170,10 +176,8 @@ class ConstraintBasedAlgorithm implements RebalanceAlgorithm {
             // Note that to ensure the algorithm is deterministic with the same inputs, do not use
             // Random functions here. Use hashcode based on the cluster topology information to get
             // a controlled randomized order is good enough.
-            Long replicaHash1 = (long) Objects
-                .hash(replica1.toString(), clusterModel.getAssignableNodes().keySet());
-            Long replicaHash2 = (long) Objects
-                .hash(replica2.toString(), clusterModel.getAssignableNodes().keySet());
+            Integer replicaHash1 = replicaHashCodeMap.get(replica1.toString());
+            Integer replicaHash2 = replicaHashCodeMap.get(replica2.toString());
             if (!replicaHash1.equals(replicaHash2)) {
               return replicaHash1.compareTo(replicaHash2);
             } else {
diff --git a/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/AssignableNode.java b/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/AssignableNode.java
index 09a3cba..06d4976 100644
--- a/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/AssignableNode.java
+++ b/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/AssignableNode.java
@@ -104,7 +104,7 @@ public class AssignableNode implements Comparable<AssignableNode> {
 
     // Update the global state after all single replications' calculation is done.
     for (String capacityKey : totalPartitionCapacity.keySet()) {
-      updateCapacityAndUtilization(capacityKey, totalPartitionCapacity.get(capacityKey));
+      updateRemainingCapacity(capacityKey, totalPartitionCapacity.get(capacityKey));
     }
   }
 
@@ -115,7 +115,7 @@ public class AssignableNode implements Comparable<AssignableNode> {
   void assign(AssignableReplica assignableReplica) {
     addToAssignmentRecord(assignableReplica);
     assignableReplica.getCapacity().entrySet().stream()
-        .forEach(capacity -> updateCapacityAndUtilization(capacity.getKey(), capacity.getValue()));
+            .forEach(capacity -> updateRemainingCapacity(capacity.getKey(), capacity.getValue()));
   }
 
   /**
@@ -145,13 +145,13 @@ public class AssignableNode implements Comparable<AssignableNode> {
 
     AssignableReplica removedReplica = partitionMap.remove(partitionName);
     removedReplica.getCapacity().entrySet().stream()
-        .forEach(entry -> updateCapacityAndUtilization(entry.getKey(), -1 * entry.getValue()));
+        .forEach(entry -> updateRemainingCapacity(entry.getKey(), -1 * entry.getValue()));
   }
 
   /**
    * @return A set of all assigned replicas on the node.
    */
-  public Set<AssignableReplica> getAssignedReplicas() {
+  Set<AssignableReplica> getAssignedReplicas() {
     return _currentAssignedReplicaMap.values().stream()
         .flatMap(replicaMap -> replicaMap.values().stream()).collect(Collectors.toSet());
   }
@@ -159,7 +159,7 @@ public class AssignableNode implements Comparable<AssignableNode> {
   /**
    * @return The current assignment in a map of <resource name, set of partition names>
    */
-  public Map<String, Set<String>> getAssignedPartitionsMap() {
+  Map<String, Set<String>> getAssignedPartitionsMap() {
     Map<String, Set<String>> assignmentMap = new HashMap<>();
     for (String resourceName : _currentAssignedReplicaMap.keySet()) {
       assignmentMap.put(resourceName, _currentAssignedReplicaMap.get(resourceName).keySet());
@@ -180,7 +180,7 @@ public class AssignableNode implements Comparable<AssignableNode> {
    * @return A set of the current assigned replicas' partition names with the top state in the
    *         specified resource.
    */
-  public Set<String> getAssignedTopStatePartitionsByResource(String resource) {
+  Set<String> getAssignedTopStatePartitionsByResource(String resource) {
     return _currentAssignedReplicaMap.getOrDefault(resource, Collections.emptyMap()).entrySet()
         .stream().filter(partitionEntry -> partitionEntry.getValue().isReplicaTopState())
         .map(partitionEntry -> partitionEntry.getKey()).collect(Collectors.toSet());
@@ -335,14 +335,13 @@ public class AssignableNode implements Comparable<AssignableNode> {
     }
   }
 
-  private void updateCapacityAndUtilization(String capacityKey, int usage) {
+  private void updateRemainingCapacity(String capacityKey, int usage) {
     if (!_remainingCapacity.containsKey(capacityKey)) {
       //if the capacityKey belongs to replicas does not exist in the instance's capacity,
       // it will be treated as if it has unlimited capacity of that capacityKey
       return;
     }
-    int newCapacity = _remainingCapacity.get(capacityKey) - usage;
-    _remainingCapacity.put(capacityKey, newCapacity);
+    _remainingCapacity.put(capacityKey, _remainingCapacity.get(capacityKey) - usage);
   }
 
   /**
diff --git a/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/AssignableReplica.java b/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/AssignableReplica.java
index 12b5105..fdcc03a 100644
--- a/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/AssignableReplica.java
+++ b/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/AssignableReplica.java
@@ -38,6 +38,7 @@ import org.slf4j.LoggerFactory;
 public class AssignableReplica implements Comparable<AssignableReplica> {
   private static final Logger LOG = LoggerFactory.getLogger(AssignableReplica.class);
 
+  private final String _replicaKey;
   private final String _partitionName;
   private final String _resourceName;
   private final String _resourceInstanceGroupTag;
@@ -64,6 +65,7 @@ public class AssignableReplica implements Comparable<AssignableReplica> {
     _capacityUsage = fetchCapacityUsage(partitionName, resourceConfig, clusterConfig);
     _resourceInstanceGroupTag = resourceConfig.getInstanceGroupTag();
     _resourceMaxPartitionsPerInstance = resourceConfig.getMaxPartitionsPerInstance();
+    _replicaKey = generateReplicaKey(_resourceName, _partitionName,_replicaState);
   }
 
   public Map<String, Integer> getCapacity() {
@@ -104,7 +106,7 @@ public class AssignableReplica implements Comparable<AssignableReplica> {
 
   @Override
   public String toString() {
-    return generateReplicaKey(_resourceName, _partitionName, _replicaState);
+    return _replicaKey;
   }
 
   @Override
diff --git a/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/ClusterModel.java b/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/ClusterModel.java
index 3d31c04..57ffa42 100644
--- a/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/ClusterModel.java
+++ b/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/ClusterModel.java
@@ -58,7 +58,7 @@ public class ClusterModel {
             .toMap(AssignableReplica::toString, replica -> replica,
                 (oldValue, newValue) -> oldValue)));
 
-    _assignableNodeMap = assignableNodes.stream()
+    _assignableNodeMap = assignableNodes.parallelStream()
         .collect(Collectors.toMap(AssignableNode::getInstanceName, node -> node));
   }
 
diff --git a/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/ClusterModelProvider.java b/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/ClusterModelProvider.java
index f777534..dc36fba 100644
--- a/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/ClusterModelProvider.java
+++ b/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/ClusterModelProvider.java
@@ -80,7 +80,7 @@ public class ClusterModelProvider {
             dataProvider.getLiveInstances().keySet(), bestPossibleAssignment, allocatedReplicas);
 
     // Update the allocated replicas to the assignable nodes.
-    assignableNodes.stream().forEach(node -> node.assignInitBatch(
+    assignableNodes.parallelStream().forEach(node -> node.assignInitBatch(
         allocatedReplicas.getOrDefault(node.getInstanceName(), Collections.emptySet())));
 
     // Construct and initialize cluster context.
@@ -207,7 +207,7 @@ public class ClusterModelProvider {
    */
   private static Set<AssignableNode> parseAllNodes(ClusterConfig clusterConfig,
       Map<String, InstanceConfig> instanceConfigMap, Set<String> activeInstances) {
-    return activeInstances.stream().map(
+    return activeInstances.parallelStream().map(
         instanceName -> new AssignableNode(clusterConfig, instanceConfigMap.get(instanceName),
             instanceName))
         .collect(Collectors.toSet());
@@ -224,10 +224,10 @@ public class ClusterModelProvider {
   private static Map<String, Set<AssignableReplica>> parseAllReplicas(
       ResourceControllerDataProvider dataProvider, Map<String, Resource> resourceMap,
       Set<AssignableNode> assignableNodes) {
-    Map<String, Set<AssignableReplica>> totalReplicaMap = new HashMap<>();
     ClusterConfig clusterConfig = dataProvider.getClusterConfig();
-
-    for (String resourceName : resourceMap.keySet()) {
+    int activeFaultZoneCount = assignableNodes.stream().map(node -> node.getFaultZone())
+        .collect(Collectors.toSet()).size();
+    return resourceMap.keySet().parallelStream().map(resourceName -> {
       ResourceConfig resourceConfig = dataProvider.getResourceConfig(resourceName);
       if (resourceConfig == null) {
         resourceConfig = new ResourceConfig(resourceName);
@@ -244,26 +244,21 @@ public class ClusterModelProvider {
             .format("Cannot find state model definition %s for resource %s.",
                 is.getStateModelDefRef(), resourceName));
       }
-
-      int activeFaultZoneCount =
-          assignableNodes.stream().map(node -> node.getFaultZone()).collect(Collectors.toSet())
-              .size();
       Map<String, Integer> stateCountMap =
           def.getStateCountMap(activeFaultZoneCount, is.getReplicaCount(assignableNodes.size()));
-
+      mergeIdealStateWithResourceConfig(resourceConfig, is);
+      Set<AssignableReplica> replicas = new HashSet<>();
       for (String partition : is.getPartitionSet()) {
         for (Map.Entry<String, Integer> entry : stateCountMap.entrySet()) {
           String state = entry.getKey();
           for (int i = 0; i < entry.getValue(); i++) {
-            mergeIdealStateWithResourceConfig(resourceConfig, is);
-            totalReplicaMap.computeIfAbsent(resourceName, key -> new HashSet<>()).add(
-                new AssignableReplica(clusterConfig, resourceConfig, partition, state,
-                    def.getStatePriorityMap().get(state)));
+            replicas.add(new AssignableReplica(clusterConfig, resourceConfig, partition, state,
+                def.getStatePriorityMap().get(state)));
           }
         }
       }
-    }
-    return totalReplicaMap;
+      return new HashMap.SimpleEntry<>(resourceName, replicas);
+    }).collect(Collectors.toMap(entry -> entry.getKey(), entry -> entry.getValue()));
   }
 
   /**
diff --git a/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/OptimalAssignment.java b/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/OptimalAssignment.java
index 138f30c..1ff00c9 100644
--- a/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/OptimalAssignment.java
+++ b/helix-core/src/main/java/org/apache/helix/controller/rebalancer/waged/model/OptimalAssignment.java
@@ -19,7 +19,7 @@ package org.apache.helix.controller.rebalancer.waged.model;
  * under the License.
  */
 
-import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -35,7 +35,7 @@ import org.apache.helix.model.ResourceAssignment;
  * Note that this class is not thread safe.
  */
 public class OptimalAssignment {
-  private Map<AssignableNode, List<AssignableReplica>> _optimalAssignment = new HashMap<>();
+  private Map<String, ResourceAssignment> _optimalAssignment = Collections.emptyMap();
   private Map<AssignableReplica, Map<AssignableNode, List<String>>> _failedAssignments =
       new HashMap<>();
 
@@ -45,23 +45,9 @@ public class OptimalAssignment {
    * @param clusterModel
    */
   public void updateAssignments(ClusterModel clusterModel) {
-    _optimalAssignment.clear();
-    clusterModel.getAssignableNodes().values().stream()
-        .forEach(node -> _optimalAssignment.put(node, new ArrayList<>(node.getAssignedReplicas())));
-  }
-
-  /**
-   * @return The optimal assignment in the form of a <Resource Name, ResourceAssignment> map.
-   */
-  public Map<String, ResourceAssignment> getOptimalResourceAssignment() {
-    if (hasAnyFailure()) {
-      throw new HelixException(
-          "Cannot get the optimal resource assignment since a calculation failure is recorded. "
-              + getFailures());
-    }
     Map<String, ResourceAssignment> assignmentMap = new HashMap<>();
-    for (AssignableNode node : _optimalAssignment.keySet()) {
-      for (AssignableReplica replica : _optimalAssignment.get(node)) {
+    for (AssignableNode node : clusterModel.getAssignableNodes().values()) {
+      for (AssignableReplica replica : node.getAssignedReplicas()) {
         String resourceName = replica.getResourceName();
         Partition partition = new Partition(replica.getPartitionName());
         ResourceAssignment resourceAssignment = assignmentMap
@@ -76,7 +62,19 @@ public class OptimalAssignment {
         resourceAssignment.addReplicaMap(partition, partitionStateMap);
       }
     }
-    return assignmentMap;
+    _optimalAssignment = assignmentMap;
+  }
+
+  /**
+   * @return The optimal assignment in the form of a <Resource Name, ResourceAssignment> map.
+   */
+  public Map<String, ResourceAssignment> getOptimalResourceAssignment() {
+    if (hasAnyFailure()) {
+      throw new HelixException(
+          "Cannot get the optimal resource assignment since a calculation failure is recorded. "
+              + getFailures());
+    }
+    return _optimalAssignment;
   }
 
   public void recordAssignmentFailure(AssignableReplica replica,
diff --git a/helix-core/src/main/java/org/apache/helix/controller/stages/BestPossibleStateCalcStage.java b/helix-core/src/main/java/org/apache/helix/controller/stages/BestPossibleStateCalcStage.java
index 671604e..fa580b7 100644
--- a/helix-core/src/main/java/org/apache/helix/controller/stages/BestPossibleStateCalcStage.java
+++ b/helix-core/src/main/java/org/apache/helix/controller/stages/BestPossibleStateCalcStage.java
@@ -52,8 +52,6 @@ import org.apache.helix.model.ResourceAssignment;
 import org.apache.helix.model.StateModelDefinition;
 import org.apache.helix.monitoring.mbeans.ClusterStatusMonitor;
 import org.apache.helix.monitoring.mbeans.ResourceMonitor;
-import org.apache.helix.monitoring.metrics.MetricCollector;
-import org.apache.helix.monitoring.metrics.WagedRebalancerMetricCollector;
 import org.apache.helix.task.TaskConstants;
 import org.apache.helix.util.HelixUtil;
 import org.slf4j.Logger;