You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by ja...@apache.org on 2019/06/14 20:46:30 UTC

[incubator-pinot] branch master updated: Add interface and implementations for the new segment assignment (#4269)

This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 2e543e9  Add interface and implementations for the new segment assignment (#4269)
2e543e9 is described below

commit 2e543e992f842918e2405c276b041d03b0958bf6
Author: Xiaotian (Jackie) Jiang <17...@users.noreply.github.com>
AuthorDate: Fri Jun 14 13:46:24 2019 -0700

    Add interface and implementations for the new segment assignment (#4269)
    
    Added InstancePartitions implementations
    Added InstancePartitionsUtils to fetch/persist it to ZK and also handle backward-compatibility if it does not exist in ZK
    Added SegmentAssignmentStrategy to handle both segment assignment and table rebalance
    Implemented the balance-number, replica-group segment assignment strategy for both OFFLINE and REALTIME tables that works on InstancePartitions
    
    There is no change to the persistent state
---
 .../pinot/common/metadata/ZKMetadataProvider.java  |   4 +-
 .../pinot/common/utils/InstancePartitionsType.java |  37 ++
 .../java/org/apache/pinot/common/utils/Pairs.java  |  37 +-
 .../helix/core/assignment/InstancePartitions.java  |  99 +++++
 .../core/assignment/InstancePartitionsUtils.java   |  97 +++++
 ...OfflineBalanceNumSegmentAssignmentStrategy.java |  96 +++++
 ...flineReplicaGroupSegmentAssignmentStrategy.java | 200 ++++++++++
 ...ealtimeBalanceNumSegmentAssignmentStrategy.java | 155 ++++++++
 ...ltimeReplicaGroupSegmentAssignmentStrategy.java | 159 ++++++++
 .../segment/SegmentAssignmentStrategy.java         |  59 +++
 .../segment/SegmentAssignmentStrategyFactory.java  |  57 +++
 .../assignment/segment/SegmentAssignmentUtils.java | 276 +++++++++++++
 ...ineBalanceNumSegmentAssignmentStrategyTest.java | 137 +++++++
 ...eReplicaGroupSegmentAssignmentStrategyTest.java | 289 ++++++++++++++
 ...imeBalanceNumSegmentAssignmentStrategyTest.java | 208 ++++++++++
 ...eReplicaGroupSegmentAssignmentStrategyTest.java | 230 +++++++++++
 .../segment/SegmentAssignmentTestUtils.java        |  39 ++
 .../segment/SegmentAssignmentUtilsTest.java        | 434 +++++++++++++++++++++
 18 files changed, 2594 insertions(+), 19 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/metadata/ZKMetadataProvider.java b/pinot-common/src/main/java/org/apache/pinot/common/metadata/ZKMetadataProvider.java
index 0915119..7cf7448 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/metadata/ZKMetadataProvider.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/metadata/ZKMetadataProvider.java
@@ -94,8 +94,8 @@ public class ZKMetadataProvider {
     return StringUtil.join("/", PROPERTYSTORE_SCHEMAS_PREFIX, schemaName);
   }
 
-  public static String constructPropertyStorePathForInstancePartitions(String tableNameWithType) {
-    return StringUtil.join("/", PROPERTYSTORE_INSTANCE_PARTITIONS_PREFIX, tableNameWithType);
+  public static String constructPropertyStorePathForInstancePartitions(String instancePartitionsName) {
+    return StringUtil.join("/", PROPERTYSTORE_INSTANCE_PARTITIONS_PREFIX, instancePartitionsName);
   }
 
   public static String constructPropertyStorePathForResource(String resourceName) {
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/InstancePartitionsType.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/InstancePartitionsType.java
new file mode 100644
index 0000000..90e42a0
--- /dev/null
+++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/InstancePartitionsType.java
@@ -0,0 +1,37 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.utils;
+
+/**
+ * The type of the instance partitions.
+ * <p>
+ *   The instance partitions name will be of the format {@code <rawTableName>_<type>}, e.g. {@code table_OFFLINE},
+ *   {@code table_CONSUMING}, {@code table_COMPLETED}.
+ */
+public enum InstancePartitionsType {
+  OFFLINE,    // For segments from offline table
+  CONSUMING,  // For consuming segments from real-time table
+  COMPLETED;  // For completed segments from real-time table
+
+  public static final char TYPE_SUFFIX_SEPARATOR = '_';
+
+  public String getInstancePartitionsName(String rawTableName) {
+    return rawTableName + TYPE_SUFFIX_SEPARATOR + name();
+  }
+}
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/Pairs.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/Pairs.java
index c67bca3..adb95e7 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/utils/Pairs.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/Pairs.java
@@ -32,46 +32,45 @@ public class Pairs {
   }
 
   public static class IntPair {
-    int a;
+    private int _left;
+    private int _right;
 
-    int b;
-
-    public IntPair(int a, int b) {
-      this.a = a;
-      this.b = b;
+    public IntPair(int left, int right) {
+      _left = left;
+      _right = right;
     }
 
     public int getLeft() {
-      return a;
+      return _left;
     }
 
     public int getRight() {
-      return b;
+      return _right;
     }
 
-    public void setLeft(int a) {
-      this.a = a;
+    public void setLeft(int left) {
+      _left = left;
     }
 
-    public void setRight(int b) {
-      this.b = b;
+    public void setRight(int right) {
+      _right = right;
     }
 
     @Override
     public String toString() {
-      return "[" + a + "," + b + "]";
+      return "[" + _left + "," + _right + "]";
     }
 
     @Override
     public int hashCode() {
-      return toString().hashCode();
+      return 37 * _left + _right;
     }
 
     @Override
     public boolean equals(Object obj) {
       if (obj instanceof IntPair) {
         IntPair that = (IntPair) obj;
-        return obj != null && a == (that.a) && b == that.b;
+        return _left == that._left && _right == that._right;
       }
       return false;
     }
@@ -80,8 +79,12 @@ public class Pairs {
   public static class AscendingIntPairComparator implements Comparator<IntPair> {
 
     @Override
-    public int compare(IntPair o1, IntPair o2) {
-      return Integer.compare(o1.a, o2.a);
+    public int compare(IntPair pair1, IntPair pair2) {
+      if (pair1._left != pair2._left) {
+        return Integer.compare(pair1._left, pair2._left);
+      } else {
+        return Integer.compare(pair1._right, pair2._right);
+      }
     }
   }
 
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/InstancePartitions.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/InstancePartitions.java
new file mode 100644
index 0000000..6d9a35f
--- /dev/null
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/InstancePartitions.java
@@ -0,0 +1,99 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+import org.apache.helix.ZNRecord;
+
+
+/**
+ * Instance (server) partitions for the table.
+ *
+ * <p>The instance partitions is stored as a map from partition of the format: {@code <partitionId>_<replicaId>} to
+ * list of server instances, and is persisted under the ZK path: {@code <cluster>/PROPERTYSTORE/INSTANCE_PARTITIONS}.
+ * <p>The segment assignment will be based on the instance partitions of the table.
+ */
+public class InstancePartitions {
+  private static final char PARTITION_REPLICA_SEPARATOR = '_';
+
+  // Name will be of the format "<rawTableName>_<type>", e.g. "table_OFFLINE", "table_CONSUMING", "table_COMPLETED"
+  private final String _name;
+  private final Map<String, List<String>> _partitionToInstancesMap;
+  private int _numPartitions;
+  private int _numReplicas;
+
+  public InstancePartitions(String name) {
+    _name = name;
+    _partitionToInstancesMap = new TreeMap<>();
+  }
+
+  private InstancePartitions(String name, Map<String, List<String>> partitionToInstancesMap) {
+    _name = name;
+    _partitionToInstancesMap = partitionToInstancesMap;
+    for (String key : partitionToInstancesMap.keySet()) {
+      int splitterIndex = key.indexOf(PARTITION_REPLICA_SEPARATOR);
+      int partition = Integer.parseInt(key.substring(0, splitterIndex));
+      int replica = Integer.parseInt(key.substring(splitterIndex + 1));
+      _numPartitions = Integer.max(_numPartitions, partition + 1);
+      _numReplicas = Integer.max(_numReplicas, replica + 1);
+    }
+  }
+
+  public String getName() {
+    return _name;
+  }
+
+  public Map<String, List<String>> getPartitionToInstancesMap() {
+    return _partitionToInstancesMap;
+  }
+
+  @JsonIgnore
+  public int getNumPartitions() {
+    return _numPartitions;
+  }
+
+  @JsonIgnore
+  public int getNumReplicas() {
+    return _numReplicas;
+  }
+
+  public List<String> getInstances(int partitionId, int replicaId) {
+    return _partitionToInstancesMap.get(Integer.toString(partitionId) + PARTITION_REPLICA_SEPARATOR + replicaId);
+  }
+
+  public void setInstances(int partitionId, int replicaId, List<String> instances) {
+    String key = Integer.toString(partitionId) + PARTITION_REPLICA_SEPARATOR + replicaId;
+    _partitionToInstancesMap.put(key, instances);
+    _numPartitions = Integer.max(_numPartitions, partitionId + 1);
+    _numReplicas = Integer.max(_numReplicas, replicaId + 1);
+  }
+
+  public static InstancePartitions fromZNRecord(ZNRecord znRecord) {
+    return new InstancePartitions(znRecord.getId(), znRecord.getListFields());
+  }
+
+  public ZNRecord toZNRecord() {
+    ZNRecord znRecord = new ZNRecord(_name);
+    znRecord.setListFields(_partitionToInstancesMap);
+    return znRecord;
+  }
+}
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/InstancePartitionsUtils.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/InstancePartitionsUtils.java
new file mode 100644
index 0000000..4121246
--- /dev/null
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/InstancePartitionsUtils.java
@@ -0,0 +1,97 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment;
+
+import java.util.Collections;
+import java.util.List;
+import org.apache.helix.AccessOption;
+import org.apache.helix.HelixManager;
+import org.apache.helix.ZNRecord;
+import org.apache.helix.model.InstanceConfig;
+import org.apache.helix.store.HelixPropertyStore;
+import org.apache.pinot.common.config.OfflineTagConfig;
+import org.apache.pinot.common.config.RealtimeTagConfig;
+import org.apache.pinot.common.config.TableConfig;
+import org.apache.pinot.common.config.TableNameBuilder;
+import org.apache.pinot.common.metadata.ZKMetadataProvider;
+import org.apache.pinot.common.utils.InstancePartitionsType;
+import org.apache.pinot.common.utils.helix.HelixHelper;
+
+
+/**
+ * Utility class for instance partitions.
+ */
+public class InstancePartitionsUtils {
+  private InstancePartitionsUtils() {
+  }
+
+  /**
+   * Fetches the instance partitions from Helix property store if exists, or computes it for backward compatibility.
+   */
+  public static InstancePartitions fetchOrComputeInstancePartitions(HelixManager helixManager, TableConfig tableConfig,
+      InstancePartitionsType instancePartitionsType) {
+    String tableNameWithType = tableConfig.getTableName();
+    String instancePartitionsName =
+        instancePartitionsType.getInstancePartitionsName(TableNameBuilder.extractRawTableName(tableNameWithType));
+
+    // Fetch the instance partitions from property store if exists
+    String path = ZKMetadataProvider.constructPropertyStorePathForInstancePartitions(instancePartitionsName);
+    ZNRecord znRecord = helixManager.getHelixPropertyStore().get(path, null, AccessOption.PERSISTENT);
+    if (znRecord != null) {
+      return InstancePartitions.fromZNRecord(znRecord);
+    }
+
+    // Compute the instance partitions (for backward compatibility)
+    // Sort all enabled instances with the server tag, rotate the list based on the table name to prevent creating
+    // hotspot servers
+    InstancePartitions instancePartitions = new InstancePartitions(instancePartitionsName);
+    List<InstanceConfig> instanceConfigs = HelixHelper.getInstanceConfigs(helixManager);
+    String serverTag;
+    switch (instancePartitionsType) {
+      case OFFLINE:
+        serverTag = new OfflineTagConfig(tableConfig).getOfflineServerTag();
+        break;
+      case CONSUMING:
+        serverTag = new RealtimeTagConfig(tableConfig).getConsumingServerTag();
+        break;
+      case COMPLETED:
+        serverTag = new RealtimeTagConfig(tableConfig).getCompletedServerTag();
+        break;
+      default:
+        throw new IllegalArgumentException();
+    }
+    List<String> instances = HelixHelper.getEnabledInstancesWithTag(instanceConfigs, serverTag);
+    instances.sort(null);
+    int numInstances = instances.size();
+    Collections.rotate(instances, -(Math.abs(tableNameWithType.hashCode()) % numInstances));
+    instancePartitions.setInstances(0, 0, instances);
+    return instancePartitions;
+  }
+
+  /**
+   * Persists the instance partitions to Helix property store.
+   *
+   * @return true if the instance partitions was successfully persisted, false otherwise
+   */
+  public static boolean persistInstancePartitions(HelixPropertyStore<ZNRecord> propertyStore,
+      InstancePartitions instancePartitions) {
+    String path = ZKMetadataProvider.constructPropertyStorePathForInstancePartitions(instancePartitions.getName());
+    return propertyStore.set(path, instancePartitions.toZNRecord(), AccessOption.PERSISTENT);
+  }
+}
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineBalanceNumSegmentAssignmentStrategy.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineBalanceNumSegmentAssignmentStrategy.java
new file mode 100644
index 0000000..4bd34b5
--- /dev/null
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineBalanceNumSegmentAssignmentStrategy.java
@@ -0,0 +1,96 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import org.apache.commons.configuration.Configuration;
+import org.apache.helix.HelixManager;
+import org.apache.pinot.common.config.TableConfig;
+import org.apache.pinot.common.utils.InstancePartitionsType;
+import org.apache.pinot.common.utils.Pairs;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Segment assignment strategy for offline segments that assigns segment to the instance with the least number of
+ * segments. In case of a tie, assigns to the instance with the smallest index in the list. The strategy ensures that
+ * replicas of the same segment are not assigned to the same server.
+ * <p>To rebalance a table, use Helix AutoRebalanceStrategy.
+ */
+public class OfflineBalanceNumSegmentAssignmentStrategy implements SegmentAssignmentStrategy {
+  private static final Logger LOGGER = LoggerFactory.getLogger(OfflineBalanceNumSegmentAssignmentStrategy.class);
+
+  private HelixManager _helixManager;
+  private TableConfig _tableConfig;
+  private String _tableNameWithType;
+  private int _replication;
+
+  @Override
+  public void init(HelixManager helixManager, TableConfig tableConfig) {
+    _helixManager = helixManager;
+    _tableConfig = tableConfig;
+    _tableNameWithType = tableConfig.getTableName();
+    _replication = tableConfig.getValidationConfig().getReplicationNumber();
+
+    LOGGER.info("Initialized OfflineBalanceNumSegmentAssignmentStrategy for table: {} with replication: {}",
+        _tableNameWithType, _replication);
+  }
+
+  @Override
+  public List<String> assignSegment(String segmentName, Map<String, Map<String, String>> currentAssignment) {
+    List<String> instances = SegmentAssignmentUtils
+        .getInstancesForBalanceNumStrategy(_helixManager, _tableConfig, _replication, InstancePartitionsType.OFFLINE);
+    int[] numSegmentsAssignedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsAssignedPerInstance(currentAssignment, instances);
+
+    // Assign the segment to the instance with the least segments, or the smallest id if there is a tie
+    int numInstances = numSegmentsAssignedPerInstance.length;
+    PriorityQueue<Pairs.IntPair> heap = new PriorityQueue<>(numInstances, Pairs.intPairComparator());
+    for (int instanceId = 0; instanceId < numInstances; instanceId++) {
+      heap.add(new Pairs.IntPair(numSegmentsAssignedPerInstance[instanceId], instanceId));
+    }
+    List<String> instancesAssigned = new ArrayList<>(_replication);
+    for (int i = 0; i < _replication; i++) {
+      instancesAssigned.add(instances.get(heap.remove().getRight()));
+    }
+
+    LOGGER.info("Assigned segment: {} to instances: {} for table: {}", segmentName, instancesAssigned,
+        _tableNameWithType);
+    return instancesAssigned;
+  }
+
+  @Override
+  public Map<String, Map<String, String>> rebalanceTable(Map<String, Map<String, String>> currentAssignment,
+      Configuration config) {
+    List<String> instances = SegmentAssignmentUtils
+        .getInstancesForBalanceNumStrategy(_helixManager, _tableConfig, _replication, InstancePartitionsType.OFFLINE);
+    Map<String, Map<String, String>> newAssignment =
+        SegmentAssignmentUtils.rebalanceTableWithHelixAutoRebalanceStrategy(currentAssignment, instances, _replication);
+
+    LOGGER.info(
+        "Rebalanced {} segments to instances: {} for table: {} with replication: {}, number of segments to be moved to each instance: {}",
+        currentAssignment.size(), instances, _tableNameWithType, _replication,
+        SegmentAssignmentUtils.getNumSegmentsToBeMovedPerInstance(currentAssignment, newAssignment));
+    return newAssignment;
+  }
+}
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineReplicaGroupSegmentAssignmentStrategy.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineReplicaGroupSegmentAssignmentStrategy.java
new file mode 100644
index 0000000..8124d30
--- /dev/null
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineReplicaGroupSegmentAssignmentStrategy.java
@@ -0,0 +1,200 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment;
+
+import com.google.common.base.Preconditions;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+import org.apache.commons.configuration.Configuration;
+import org.apache.helix.HelixManager;
+import org.apache.pinot.common.config.ReplicaGroupStrategyConfig;
+import org.apache.pinot.common.config.TableConfig;
+import org.apache.pinot.common.metadata.ZKMetadataProvider;
+import org.apache.pinot.common.metadata.segment.ColumnPartitionMetadata;
+import org.apache.pinot.common.metadata.segment.OfflineSegmentZKMetadata;
+import org.apache.pinot.common.utils.InstancePartitionsType;
+import org.apache.pinot.controller.helix.core.assignment.InstancePartitions;
+import org.apache.pinot.controller.helix.core.assignment.InstancePartitionsUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Segment assignment strategy for offline segments that assigns segment to the instance in the replica with the least
+ * number of segments.
+ * <p>Among multiple replicas, always mirror the assignment (pick the same index of the instance).
+ * <p>Inside each partition, assign the segment to the servers with the least segments already assigned. In case of a
+ * tie, assign to the server with the smallest index in the list. Do this for one replica and mirror the assignment to
+ * other replicas.
+ * <p>To rebalance a table, inside each partition, first calculate the number of segments on each server, loop over all
+ * the segments and keep the assignment if number of segments for the server has not been reached and track the not
+ * assigned segments, then assign the left-over segments to the servers with the least segments, or the smallest index
+ * if there is a tie. Repeat the process for all the partitions in one replica, and mirror the assignment to other
+ * replicas. With this greedy algorithm, the result is deterministic and with minimum segment moves.
+ */
+public class OfflineReplicaGroupSegmentAssignmentStrategy implements SegmentAssignmentStrategy {
+  private static final Logger LOGGER = LoggerFactory.getLogger(OfflineReplicaGroupSegmentAssignmentStrategy.class);
+
+  private HelixManager _helixManager;
+  private TableConfig _tableConfig;
+  private String _tableNameWithType;
+  private String _partitionColumn;
+
+  @Override
+  public void init(HelixManager helixManager, TableConfig tableConfig) {
+    _helixManager = helixManager;
+    _tableConfig = tableConfig;
+    _tableNameWithType = tableConfig.getTableName();
+    ReplicaGroupStrategyConfig strategyConfig = tableConfig.getValidationConfig().getReplicaGroupStrategyConfig();
+    _partitionColumn = strategyConfig != null ? strategyConfig.getPartitionColumn() : null;
+
+    if (_partitionColumn == null) {
+      LOGGER.info("Initialized OfflineReplicaGroupSegmentAssignmentStrategy for table: {} without partition column",
+          _tableNameWithType);
+    } else {
+      LOGGER.info("Initialized OfflineReplicaGroupSegmentAssignmentStrategy for table: {} with partition column: {}",
+          _tableNameWithType, _partitionColumn);
+    }
+  }
+
+  @Override
+  public List<String> assignSegment(String segmentName, Map<String, Map<String, String>> currentAssignment) {
+    InstancePartitions instancePartitions = InstancePartitionsUtils
+        .fetchOrComputeInstancePartitions(_helixManager, _tableConfig, InstancePartitionsType.OFFLINE);
+
+    // Fetch partition id from segment ZK metadata if partition column is configured
+    int partitionId = 0;
+    if (_partitionColumn == null) {
+      Preconditions.checkState(instancePartitions.getNumPartitions() == 1,
+          "The instance partitions: %s should contain only 1 partition", instancePartitions.getName());
+    } else {
+      OfflineSegmentZKMetadata segmentZKMetadata = ZKMetadataProvider
+          .getOfflineSegmentZKMetadata(_helixManager.getHelixPropertyStore(), _tableNameWithType, segmentName);
+      Preconditions
+          .checkState(segmentZKMetadata != null, "Failed to fetch segment ZK metadata for table: %s, segment: %s",
+              _tableNameWithType, segmentName);
+      // Uniformly spray the segment partitions over the instance partitions
+      partitionId = getPartitionId(segmentZKMetadata) % instancePartitions.getNumPartitions();
+    }
+
+    // First assign the segment to replica 0
+    List<String> instances = instancePartitions.getInstances(partitionId, 0);
+    int[] numSegmentsAssignedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsAssignedPerInstance(currentAssignment, instances);
+    int minNumSegmentsAssigned = numSegmentsAssignedPerInstance[0];
+    int instanceIdWithLeastSegmentsAssigned = 0;
+    int numInstances = numSegmentsAssignedPerInstance.length;
+    for (int instanceId = 1; instanceId < numInstances; instanceId++) {
+      if (numSegmentsAssignedPerInstance[instanceId] < minNumSegmentsAssigned) {
+        minNumSegmentsAssigned = numSegmentsAssignedPerInstance[instanceId];
+        instanceIdWithLeastSegmentsAssigned = instanceId;
+      }
+    }
+
+    // Mirror the assignment to all replicas
+    int numReplicas = instancePartitions.getNumReplicas();
+    List<String> instancesAssigned = new ArrayList<>(numReplicas);
+    for (int replicaId = 0; replicaId < numReplicas; replicaId++) {
+      instancesAssigned
+          .add(instancePartitions.getInstances(partitionId, replicaId).get(instanceIdWithLeastSegmentsAssigned));
+    }
+
+    if (_partitionColumn == null) {
+      LOGGER.info("Assigned segment: {} to instances: {} for table: {}", segmentName, instancesAssigned,
+          _tableNameWithType);
+    } else {
+      LOGGER.info("Assigned segment: {} with partition id: {} to instances: {} for table: {}", segmentName, partitionId,
+          instancesAssigned, _tableNameWithType);
+    }
+    return instancesAssigned;
+  }
+
+  @Override
+  public Map<String, Map<String, String>> rebalanceTable(Map<String, Map<String, String>> currentAssignment,
+      Configuration config) {
+    InstancePartitions instancePartitions = InstancePartitionsUtils
+        .fetchOrComputeInstancePartitions(_helixManager, _tableConfig, InstancePartitionsType.OFFLINE);
+    if (_partitionColumn == null) {
+      return rebalanceTableWithoutPartition(currentAssignment, instancePartitions);
+    } else {
+      return rebalanceTableWithPartition(currentAssignment, instancePartitions);
+    }
+  }
+
+  private Map<String, Map<String, String>> rebalanceTableWithoutPartition(
+      Map<String, Map<String, String>> currentAssignment, InstancePartitions instancePartitions) {
+    Preconditions.checkState(instancePartitions.getNumPartitions() == 1,
+        "The instance partitions: %s should contain only 1 partition", instancePartitions.getName());
+
+    Map<String, Map<String, String>> newAssignment = new TreeMap<>();
+    SegmentAssignmentUtils
+        .rebalanceReplicaGroupBasedPartition(currentAssignment, instancePartitions, 0, currentAssignment.keySet(),
+            newAssignment);
+
+    LOGGER.info(
+        "Rebalanced {} segments with instance partitions: {} for table: {} without partition column, number of segments to be moved to each instance: {}",
+        currentAssignment.size(), instancePartitions.getPartitionToInstancesMap(), _tableNameWithType,
+        SegmentAssignmentUtils.getNumSegmentsToBeMovedPerInstance(currentAssignment, newAssignment));
+    return newAssignment;
+  }
+
+  private Map<String, Map<String, String>> rebalanceTableWithPartition(
+      Map<String, Map<String, String>> currentAssignment, InstancePartitions instancePartitions) {
+    // Fetch partition id from segment ZK metadata
+    List<OfflineSegmentZKMetadata> segmentZKMetadataList = ZKMetadataProvider
+        .getOfflineSegmentZKMetadataListForTable(_helixManager.getHelixPropertyStore(), _tableNameWithType);
+    Map<String, OfflineSegmentZKMetadata> segmentZKMetadataMap = new HashMap<>();
+    for (OfflineSegmentZKMetadata segmentZKMetadata : segmentZKMetadataList) {
+      segmentZKMetadataMap.put(segmentZKMetadata.getSegmentName(), segmentZKMetadata);
+    }
+    Map<Integer, Set<String>> partitionIdToSegmentsMap = new HashMap<>();
+    for (String segmentName : currentAssignment.keySet()) {
+      int partitionId = getPartitionId(segmentZKMetadataMap.get(segmentName));
+      partitionIdToSegmentsMap.computeIfAbsent(partitionId, k -> new HashSet<>()).add(segmentName);
+    }
+
+    Map<String, Map<String, String>> newAssignment = SegmentAssignmentUtils
+        .rebalanceReplicaGroupBasedTable(currentAssignment, instancePartitions, partitionIdToSegmentsMap);
+
+    LOGGER.info(
+        "Rebalanced {} segments with instance partitions: {} for table: {} with partition column: {}, number of segments to be moved to each instance: {}",
+        currentAssignment.size(), instancePartitions.getPartitionToInstancesMap(), _tableNameWithType, _partitionColumn,
+        SegmentAssignmentUtils.getNumSegmentsToBeMovedPerInstance(currentAssignment, newAssignment));
+    return newAssignment;
+  }
+
+  private int getPartitionId(OfflineSegmentZKMetadata segmentZKMetadata) {
+    String segmentName = segmentZKMetadata.getSegmentName();
+    ColumnPartitionMetadata partitionMetadata =
+        segmentZKMetadata.getPartitionMetadata().getColumnPartitionMap().get(_partitionColumn);
+    Preconditions.checkState(partitionMetadata != null,
+        "Segment ZK metadata for table: %s, segment: %s does not contain partition metadata for column: %s",
+        _tableNameWithType, segmentName, _partitionColumn);
+    Set<Integer> partitions = partitionMetadata.getPartitions();
+    Preconditions.checkState(partitions.size() == 1,
+        "Segment ZK metadata for table: %s, segment: %s contains multiple partitions for column: %s",
+        _tableNameWithType, segmentName, _partitionColumn);
+    return partitions.iterator().next();
+  }
+}
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeBalanceNumSegmentAssignmentStrategy.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeBalanceNumSegmentAssignmentStrategy.java
new file mode 100644
index 0000000..72bdf29
--- /dev/null
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeBalanceNumSegmentAssignmentStrategy.java
@@ -0,0 +1,155 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import org.apache.commons.configuration.Configuration;
+import org.apache.helix.HelixManager;
+import org.apache.pinot.common.config.TableConfig;
+import org.apache.pinot.common.utils.CommonConstants.Helix.StateModel.RealtimeSegmentOnlineOfflineStateModel;
+import org.apache.pinot.common.utils.InstancePartitionsType;
+import org.apache.pinot.common.utils.LLCSegmentName;
+import org.apache.pinot.controller.helix.core.assignment.segment.SegmentAssignmentUtils.CompletedConsumingSegmentAssignmentPair;
+import org.apache.pinot.controller.helix.core.rebalance.RebalanceUserConfigConstants;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Segment assignment strategy for LLC real-time segments without replica-group.
+ * <ul>
+ *   <li>
+ *     For the CONSUMING segments, it is very similar to replica-group based segment assignment with the following
+ *     differences:
+ *     <ul>
+ *       <li>
+ *         1. Within a replica, all segments of a partition (stream partition) always exist in exactly one server
+ *       </li>
+ *       <li>
+ *         2. Partition id for an instance is derived from the index of the instance, instead of explicitly stored in
+ *         the instance partitions
+ *       </li>
+ *       <li>
+ *         3. In addition to the ONLINE segments, there are also CONSUMING segments to be assigned
+ *       </li>
+ *     </ul>
+ *     Since within a replica, each partition contains only one server, we can directly assign or rebalance the
+ *     CONSUMING segments to the servers based on the partition id.
+ *     <p>The strategy does not minimize segment movements for CONSUMING segments because within a replica, the server
+ *     is fixed for each partition. The instance assignment is responsible for keeping minimum changes to the instance
+ *     partitions to reduce the number of segments that need to be moved.
+ *   </li>
+ *   <li>
+ *     For the COMPLETED segments, rebalance segments the same way as OfflineBalanceNumSegmentAssignmentStrategy.
+ *   </li>
+ * </ul>
+ */
+public class RealtimeBalanceNumSegmentAssignmentStrategy implements SegmentAssignmentStrategy {
+  private static final Logger LOGGER = LoggerFactory.getLogger(RealtimeBalanceNumSegmentAssignmentStrategy.class);
+
+  // Set in init()
+  private HelixManager _helixManager;
+  private TableConfig _tableConfig;
+  private String _tableNameWithType;
+  private int _replication;
+
+  /**
+   * Stores the Helix manager and the table config, and reads the table name and the replication (replicas per
+   * partition) from the table config.
+   *
+   * @param helixManager Helix manager
+   * @param tableConfig Table config
+   */
+  @Override
+  public void init(HelixManager helixManager, TableConfig tableConfig) {
+    _helixManager = helixManager;
+    _tableConfig = tableConfig;
+    _tableNameWithType = tableConfig.getTableName();
+    _replication = tableConfig.getValidationConfig().getReplicasPerPartitionNumber();
+
+    LOGGER.info("Initialized RealtimeBalanceNumSegmentAssignmentStrategy for table: {} with replication: {}",
+        _tableNameWithType, _replication);
+  }
+
+  /**
+   * Assigns a new LLC segment to the instances picked for its partition from the CONSUMING instances.
+   * <p>The partition id is parsed from the segment name. The current assignment is not used.
+   *
+   * @param segmentName Name of the LLC segment to be assigned
+   * @param currentAssignment Current segment assignment of the table (not used)
+   * @return List of instances to assign the segment to, one per replica
+   */
+  @Override
+  public List<String> assignSegment(String segmentName, Map<String, Map<String, String>> currentAssignment) {
+    List<String> instances = SegmentAssignmentUtils
+        .getInstancesForBalanceNumStrategy(_helixManager, _tableConfig, _replication, InstancePartitionsType.CONSUMING);
+    int partitionId = new LLCSegmentName(segmentName).getPartitionId();
+    List<String> instancesAssigned = getInstances(instances, partitionId);
+    LOGGER.info("Assigned segment: {} with partition id: {} to instances: {} for table: {}", segmentName, partitionId,
+        instancesAssigned, _tableNameWithType);
+    return instancesAssigned;
+  }
+
+  /**
+   * Rebalances the table.
+   * <ul>
+   *   <li>COMPLETED segments are always rebalanced over the COMPLETED instances with Helix AutoRebalanceStrategy</li>
+   *   <li>
+   *     CONSUMING segments are re-assigned by partition id over the CONSUMING instances only if the
+   *     {@link RebalanceUserConfigConstants#INCLUDE_CONSUMING} config is enabled, otherwise their current assignment
+   *     is kept as is
+   *   </li>
+   * </ul>
+   *
+   * @param currentAssignment Current segment assignment of the table (map from segment name to instance state map)
+   * @param config Configuration for the rebalance
+   * @return the rebalanced assignment for both the COMPLETED and the CONSUMING segments
+   */
+  @Override
+  public Map<String, Map<String, String>> rebalanceTable(Map<String, Map<String, String>> currentAssignment,
+      Configuration config) {
+    CompletedConsumingSegmentAssignmentPair pair = new CompletedConsumingSegmentAssignmentPair(currentAssignment);
+
+    // Rebalance COMPLETED segments first
+    Map<String, Map<String, String>> completedSegmentAssignment = pair.getCompletedSegmentAssignment();
+    List<String> instancesForCompletedSegments = SegmentAssignmentUtils
+        .getInstancesForBalanceNumStrategy(_helixManager, _tableConfig, _replication, InstancePartitionsType.COMPLETED);
+    Map<String, Map<String, String>> newAssignment = SegmentAssignmentUtils
+        .rebalanceTableWithHelixAutoRebalanceStrategy(completedSegmentAssignment, instancesForCompletedSegments,
+            _replication);
+
+    // Rebalance CONSUMING segments if needed
+    Map<String, Map<String, String>> consumingSegmentAssignment = pair.getConsumingSegmentAssignment();
+    boolean includeConsuming = config.getBoolean(RebalanceUserConfigConstants.INCLUDE_CONSUMING,
+        RebalanceUserConfigConstants.DEFAULT_INCLUDE_CONSUMING);
+    if (includeConsuming) {
+      List<String> instancesForConsumingSegments = SegmentAssignmentUtils
+          .getInstancesForBalanceNumStrategy(_helixManager, _tableConfig, _replication,
+              InstancePartitionsType.CONSUMING);
+      for (String segmentName : consumingSegmentAssignment.keySet()) {
+        int partitionId = new LLCSegmentName(segmentName).getPartitionId();
+        List<String> instancesAssigned = getInstances(instancesForConsumingSegments, partitionId);
+        Map<String, String> instanceStateMap = SegmentAssignmentUtils
+            .getInstanceStateMap(instancesAssigned, RealtimeSegmentOnlineOfflineStateModel.CONSUMING);
+        newAssignment.put(segmentName, instanceStateMap);
+      }
+      LOGGER.info(
+          "Rebalanced {} COMPLETED segments to instances: {} and {} CONSUMING segments to instances: {} for table: {} with replication: {}, number of segments to be moved to each instances: {}",
+          completedSegmentAssignment.size(), instancesForCompletedSegments, consumingSegmentAssignment.size(),
+          instancesForConsumingSegments, _tableNameWithType, _replication,
+          SegmentAssignmentUtils.getNumSegmentsToBeMovedPerInstance(currentAssignment, newAssignment));
+    } else {
+      // NOTE: the segments to be moved are computed on the COMPLETED segments only, before the untouched CONSUMING
+      // segments are added back to the new assignment
+      LOGGER.info(
+          "Rebalanced {} COMPLETED segments to instances: {} for table: {} with replication: {}, number of segments to be moved to each instance: {}",
+          completedSegmentAssignment.size(), instancesForCompletedSegments, _tableNameWithType, _replication,
+          SegmentAssignmentUtils.getNumSegmentsToBeMovedPerInstance(completedSegmentAssignment, newAssignment));
+      newAssignment.putAll(consumingSegmentAssignment);
+    }
+
+    return newAssignment;
+  }
+
+  /**
+   * Returns the instances for the given partition id for CONSUMING segments.
+   * <p>Uniformly spray the partitions and replicas across the instances: replica {@code r} of partition {@code p}
+   * goes to the instance at index {@code (p * replication + r) % numInstances}.
+   * <p>E.g. (6 servers, 3 partitions, 4 replicas)
+   * <pre>
+   *   "0_0": [i0,   i1,   i2,   i3,   i4,   i5  ]
+   *           p0r0, p0r1, p0r2, p0r3, p1r0, p1r1
+   *           p1r2, p1r3, p2r0, p2r1, p2r2, p2r3
+   * </pre>
+   *
+   * @param instances Instances to pick from
+   * @param partitionId Partition id of the segment
+   * @return List of instances picked, one per replica
+   */
+  private List<String> getInstances(List<String> instances, int partitionId) {
+    List<String> instancesAssigned = new ArrayList<>(_replication);
+    for (int replicaId = 0; replicaId < _replication; replicaId++) {
+      instancesAssigned.add(instances.get((partitionId * _replication + replicaId) % instances.size()));
+    }
+    return instancesAssigned;
+  }
+}
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeReplicaGroupSegmentAssignmentStrategy.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeReplicaGroupSegmentAssignmentStrategy.java
new file mode 100644
index 0000000..afd74be
--- /dev/null
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeReplicaGroupSegmentAssignmentStrategy.java
@@ -0,0 +1,159 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment;
+
+import com.google.common.base.Preconditions;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.commons.configuration.Configuration;
+import org.apache.helix.HelixManager;
+import org.apache.pinot.common.config.TableConfig;
+import org.apache.pinot.common.utils.CommonConstants;
+import org.apache.pinot.common.utils.InstancePartitionsType;
+import org.apache.pinot.common.utils.LLCSegmentName;
+import org.apache.pinot.controller.helix.core.assignment.InstancePartitions;
+import org.apache.pinot.controller.helix.core.assignment.InstancePartitionsUtils;
+import org.apache.pinot.controller.helix.core.rebalance.RebalanceUserConfigConstants;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Segment assignment strategy for LLC real-time segments (both consuming and completed).
+ * <p>It is very similar to replica-group based segment assignment with the following differences:
+ * <ul>
+ *   <li>1. Inside one replica, each partition (stream partition) always contains one server</li>
+ *   <li>
+ *     2. Partition id for an instance is derived from the index of the instance in the replica group, instead of
+ *     explicitly stored in the instance partitions
+ *   </li>
+ *   <li>3. In addition to the ONLINE segments, there are also CONSUMING segments to be assigned</li>
+ * </ul>
+ * <p>
+ *   Since each partition contains only one server (in one replica), we can directly assign or rebalance segments to the
+ *   servers based on the partition id.
+ * <p>
+ *   The real-time segment assignment does not minimize segment moves because the server is fixed for each partition in
+ *   each replica. The instance assignment is responsible for keeping minimum changes to the instance partitions to
+ *   reduce the number of segments that need to be moved.
+ */
+public class RealtimeReplicaGroupSegmentAssignmentStrategy implements SegmentAssignmentStrategy {
+  private static final Logger LOGGER = LoggerFactory.getLogger(RealtimeReplicaGroupSegmentAssignmentStrategy.class);
+
+  // Set in init()
+  private HelixManager _helixManager;
+  private TableConfig _tableConfig;
+  private String _tableNameWithType;
+
+  /**
+   * Stores the Helix manager and the table config, and reads the table name from the table config.
+   *
+   * @param helixManager Helix manager
+   * @param tableConfig Table config
+   */
+  @Override
+  public void init(HelixManager helixManager, TableConfig tableConfig) {
+    _helixManager = helixManager;
+    _tableConfig = tableConfig;
+    _tableNameWithType = tableConfig.getTableName();
+
+    LOGGER.info("Initialized RealtimeReplicaGroupSegmentAssignmentStrategy for table: {}", _tableNameWithType);
+  }
+
+  /**
+   * Assigns a new LLC segment to one instance in each replica of the CONSUMING instance partitions.
+   * <p>The instance is picked by the partition id parsed from the segment name. The current assignment is not used.
+   *
+   * @param segmentName Name of the LLC segment to be assigned
+   * @param currentAssignment Current segment assignment of the table (not used)
+   * @return List of instances to assign the segment to, one per replica
+   * @throws IllegalStateException if the CONSUMING instance partitions do not contain exactly 1 partition
+   */
+  @Override
+  public List<String> assignSegment(String segmentName, Map<String, Map<String, String>> currentAssignment) {
+    InstancePartitions instancePartitions = fetchConsumingInstancePartitions();
+
+    int partitionId = new LLCSegmentName(segmentName).getPartitionId();
+    List<String> instancesAssigned = getInstances(instancePartitions, partitionId);
+    LOGGER.info("Assigned segment: {} with partition id: {} to instances: {} for table: {}", segmentName, partitionId,
+        instancesAssigned, _tableNameWithType);
+    return instancesAssigned;
+  }
+
+  /**
+   * Rebalances the table.
+   * <ul>
+   *   <li>
+   *     COMPLETED segments are always rebalanced with the replica-group based rebalance over the COMPLETED instance
+   *     partitions, grouped by the partition id parsed from the segment name
+   *   </li>
+   *   <li>
+   *     CONSUMING segments are re-assigned by partition id over the CONSUMING instance partitions only if the
+   *     {@link RebalanceUserConfigConstants#INCLUDE_CONSUMING} config is enabled, otherwise their current assignment
+   *     is kept as is
+   *   </li>
+   * </ul>
+   *
+   * @param currentAssignment Current segment assignment of the table (map from segment name to instance state map)
+   * @param config Configuration for the rebalance
+   * @return the rebalanced assignment for both the COMPLETED and the CONSUMING segments
+   * @throws IllegalStateException if CONSUMING segments are included and the CONSUMING instance partitions do not
+   *         contain exactly 1 partition
+   */
+  @Override
+  public Map<String, Map<String, String>> rebalanceTable(Map<String, Map<String, String>> currentAssignment,
+      Configuration config) {
+    SegmentAssignmentUtils.CompletedConsumingSegmentAssignmentPair pair =
+        new SegmentAssignmentUtils.CompletedConsumingSegmentAssignmentPair(currentAssignment);
+
+    // Rebalance COMPLETED segments first
+    Map<String, Map<String, String>> completedSegmentAssignment = pair.getCompletedSegmentAssignment();
+    InstancePartitions instancePartitionsForCompletedSegments = InstancePartitionsUtils
+        .fetchOrComputeInstancePartitions(_helixManager, _tableConfig, InstancePartitionsType.COMPLETED);
+    Map<Integer, Set<String>> partitionIdToSegmentsMap = new HashMap<>();
+    for (String segmentName : completedSegmentAssignment.keySet()) {
+      int partitionId = new LLCSegmentName(segmentName).getPartitionId();
+      partitionIdToSegmentsMap.computeIfAbsent(partitionId, k -> new HashSet<>()).add(segmentName);
+    }
+    Map<String, Map<String, String>> newAssignment = SegmentAssignmentUtils
+        .rebalanceReplicaGroupBasedTable(completedSegmentAssignment, instancePartitionsForCompletedSegments,
+            partitionIdToSegmentsMap);
+
+    // Rebalance CONSUMING segments if needed
+    Map<String, Map<String, String>> consumingSegmentAssignment = pair.getConsumingSegmentAssignment();
+    if (config.getBoolean(RebalanceUserConfigConstants.INCLUDE_CONSUMING,
+        RebalanceUserConfigConstants.DEFAULT_INCLUDE_CONSUMING)) {
+      // Same validation as for assigning a new segment, because getInstances() only reads partition 0
+      InstancePartitions instancePartitionsForConsumingSegments = fetchConsumingInstancePartitions();
+      for (String segmentName : consumingSegmentAssignment.keySet()) {
+        int partitionId = new LLCSegmentName(segmentName).getPartitionId();
+        List<String> instancesAssigned = getInstances(instancePartitionsForConsumingSegments, partitionId);
+        Map<String, String> instanceStateMap = SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned,
+            CommonConstants.Helix.StateModel.RealtimeSegmentOnlineOfflineStateModel.CONSUMING);
+        newAssignment.put(segmentName, instanceStateMap);
+      }
+      LOGGER.info(
+          "Rebalanced {} COMPLETED segments with instance partitions: {} and {} CONSUMING segments with instance partitions: {} for table: {}, number of segments to be moved to each instances: {}",
+          completedSegmentAssignment.size(), instancePartitionsForCompletedSegments.getPartitionToInstancesMap(),
+          consumingSegmentAssignment.size(), instancePartitionsForConsumingSegments.getPartitionToInstancesMap(),
+          _tableNameWithType,
+          SegmentAssignmentUtils.getNumSegmentsToBeMovedPerInstance(currentAssignment, newAssignment));
+    } else {
+      // NOTE: the segments to be moved are computed on the COMPLETED segments only, before the untouched CONSUMING
+      // segments are added back to the new assignment
+      LOGGER.info(
+          "Rebalanced {} COMPLETED segments with instance partitions: {} for table: {}, number of segments to be moved to each instance: {}",
+          completedSegmentAssignment.size(), instancePartitionsForCompletedSegments.getPartitionToInstancesMap(),
+          _tableNameWithType,
+          SegmentAssignmentUtils.getNumSegmentsToBeMovedPerInstance(completedSegmentAssignment, newAssignment));
+      newAssignment.putAll(consumingSegmentAssignment);
+    }
+
+    return newAssignment;
+  }
+
+  /**
+   * Fetches the CONSUMING instance partitions and validates that they contain exactly 1 partition, which is the only
+   * partition read when picking the instances for a CONSUMING segment.
+   *
+   * @return the validated CONSUMING instance partitions
+   * @throws IllegalStateException if the CONSUMING instance partitions do not contain exactly 1 partition
+   */
+  private InstancePartitions fetchConsumingInstancePartitions() {
+    InstancePartitions instancePartitions = InstancePartitionsUtils
+        .fetchOrComputeInstancePartitions(_helixManager, _tableConfig, InstancePartitionsType.CONSUMING);
+    Preconditions.checkState(instancePartitions.getNumPartitions() == 1,
+        "The instance partitions: %s should contain only 1 partition", instancePartitions.getName());
+    return instancePartitions;
+  }
+
+  /**
+   * Returns the instances for the given partition id for CONSUMING segments.
+   * <p>Within a replica, uniformly spray the partitions across the instances: the partition goes to the instance at
+   * index {@code partitionId % numInstancesInReplica}. Only partition 0 of the instance partitions is read.
+   * <p>E.g. (within a replica, 3 servers, 6 partitions)
+   * <pre>
+   *   "0_0": [i0, i1, i2]
+   *           p0, p1, p2
+   *           p3, p4, p5
+   * </pre>
+   *
+   * @param instancePartitions Instance partitions to pick the instances from
+   * @param partitionId Partition id of the segment
+   * @return List of instances picked, one per replica
+   */
+  private List<String> getInstances(InstancePartitions instancePartitions, int partitionId) {
+    int numReplicas = instancePartitions.getNumReplicas();
+    List<String> instancesAssigned = new ArrayList<>(numReplicas);
+    for (int replicaId = 0; replicaId < numReplicas; replicaId++) {
+      List<String> instances = instancePartitions.getInstances(0, replicaId);
+      instancesAssigned.add(instances.get(partitionId % instances.size()));
+    }
+    return instancesAssigned;
+  }
+}
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentStrategy.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentStrategy.java
new file mode 100644
index 0000000..de2ea0d
--- /dev/null
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentStrategy.java
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.commons.configuration.Configuration;
+import org.apache.helix.HelixManager;
+import org.apache.pinot.common.config.TableConfig;
+
+
+/**
+ * Strategy to assign segment to instances or rebalance all segments in a table.
+ * <p>A strategy must be initialized with {@link #init(HelixManager, TableConfig)} before it is asked to assign a
+ * segment or to rebalance the table.
+ */
+public interface SegmentAssignmentStrategy {
+
+  /**
+   * Initializes the segment assignment strategy.
+   * <p>Must be called once before any other method of the strategy.
+   *
+   * @param helixManager Helix manager
+   * @param tableConfig Table config
+   */
+  void init(HelixManager helixManager, TableConfig tableConfig);
+
+  /**
+   * Assigns a new segment.
+   * <p>This method only picks the servers for the segment; it returns them and leaves the given assignment for the
+   * caller to update.
+   *
+   * @param segmentName Name of the segment to be assigned
+   * @param currentAssignment Current segment assignment of the table (map from segment name to instance state map)
+   * @return List of servers to assign the segment to
+   */
+  List<String> assignSegment(String segmentName, Map<String, Map<String, String>> currentAssignment);
+
+  /**
+   * Rebalances the segments for a table.
+   * <p>The rebalanced assignment is returned as a new map from segment name to instance state map.
+   *
+   * @param currentAssignment Current segment assignment of the table (map from segment name to instance state map)
+   * @param config Configuration for the rebalance
+   * @return the rebalanced assignment for the segments
+   */
+  Map<String, Map<String, String>> rebalanceTable(Map<String, Map<String, String>> currentAssignment,
+      Configuration config);
+}
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentStrategyFactory.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentStrategyFactory.java
new file mode 100644
index 0000000..d71b1d9
--- /dev/null
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentStrategyFactory.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment;
+
+import org.apache.helix.HelixManager;
+import org.apache.pinot.common.config.TableConfig;
+import org.apache.pinot.common.utils.CommonConstants.Helix.TableType;
+
+
+/**
+ * Factory for the {@link SegmentAssignmentStrategy}.
+ */
+public class SegmentAssignmentStrategyFactory {
+  private SegmentAssignmentStrategyFactory() {
+  }
+
+  public enum Strategy {
+    BalanceNumSegmentAssignmentStrategy, ReplicaGroupSegmentAssignmentStrategy
+  }
+
+  public static SegmentAssignmentStrategy getSegmentAssignmentStrategy(HelixManager helixManager,
+      TableConfig tableConfig) {
+    SegmentAssignmentStrategy segmentAssignmentStrategy;
+    if (Strategy.ReplicaGroupSegmentAssignmentStrategy.name()
+        .equalsIgnoreCase(tableConfig.getValidationConfig().getSegmentAssignmentStrategy())) {
+      if (tableConfig.getTableType() == TableType.OFFLINE) {
+        segmentAssignmentStrategy = new OfflineReplicaGroupSegmentAssignmentStrategy();
+      } else {
+        segmentAssignmentStrategy = new RealtimeReplicaGroupSegmentAssignmentStrategy();
+      }
+    } else {
+      if (tableConfig.getTableType() == TableType.OFFLINE) {
+        segmentAssignmentStrategy = new OfflineBalanceNumSegmentAssignmentStrategy();
+      } else {
+        segmentAssignmentStrategy = new RealtimeBalanceNumSegmentAssignmentStrategy();
+      }
+    }
+    segmentAssignmentStrategy.init(helixManager, tableConfig);
+    return segmentAssignmentStrategy;
+  }
+}
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentUtils.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentUtils.java
new file mode 100644
index 0000000..4333f91
--- /dev/null
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentUtils.java
@@ -0,0 +1,276 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment;
+
+import com.google.common.base.Preconditions;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.Set;
+import java.util.TreeMap;
+import org.apache.helix.HelixManager;
+import org.apache.helix.controller.rebalancer.strategy.AutoRebalanceStrategy;
+import org.apache.pinot.common.config.TableConfig;
+import org.apache.pinot.common.utils.CommonConstants.Helix.StateModel.RealtimeSegmentOnlineOfflineStateModel;
+import org.apache.pinot.common.utils.CommonConstants.Helix.StateModel.SegmentOnlineOfflineStateModel;
+import org.apache.pinot.common.utils.InstancePartitionsType;
+import org.apache.pinot.common.utils.Pairs;
+import org.apache.pinot.controller.helix.core.assignment.InstancePartitions;
+import org.apache.pinot.controller.helix.core.assignment.InstancePartitionsUtils;
+
+
+/**
+ * Utility class for segment assignment.
+ */
+class SegmentAssignmentUtils {
+  private SegmentAssignmentUtils() {
+    // Private constructor: the class cannot be instantiated from outside
+  }
+
+  /**
+   * Counts the segments assigned to each of the given instances.
+   * <p>Instances in the assignment that are not in the given list are ignored.
+   *
+   * @param segmentAssignment Segment assignment (map from segment name to instance state map)
+   * @param instances Instances to count the segments for
+   * @return array in which the i-th entry is the number of segments assigned to the i-th instance of the list
+   */
+  static int[] getNumSegmentsAssignedPerInstance(Map<String, Map<String, String>> segmentAssignment,
+      List<String> instances) {
+    Map<String, Integer> instanceIds = getInstanceNameToIdMap(instances);
+    int[] segmentCounts = new int[instances.size()];
+    for (Map<String, String> instanceStateMap : segmentAssignment.values()) {
+      for (String instanceName : instanceStateMap.keySet()) {
+        Integer instanceId = instanceIds.get(instanceName);
+        if (instanceId == null) {
+          // Instance is not one of the tracked instances
+          continue;
+        }
+        segmentCounts[instanceId]++;
+      }
+    }
+    return segmentCounts;
+  }
+
+  /**
+   * Returns the map from instance name to instance id, where the id is the index of the instance in the given list.
+   * <p>If an instance shows up more than once in the list, its last index is kept.
+   */
+  private static Map<String, Integer> getInstanceNameToIdMap(List<String> instances) {
+    Map<String, Integer> instanceIds = new HashMap<>();
+    int instanceId = 0;
+    for (String instanceName : instances) {
+      instanceIds.put(instanceName, instanceId);
+      instanceId++;
+    }
+    return instanceIds;
+  }
+
+  /**
+   * Returns the instances for the balance number segment assignment strategy.
+   * <p>The instance partitions of the given type are obtained with
+   * {@link InstancePartitionsUtils#fetchOrComputeInstancePartitions}. They must contain exactly 1 partition and 1
+   * replica, and the instances of that single partition and replica are returned.
+   *
+   * @param helixManager Helix manager
+   * @param tableConfig Table config
+   * @param replication Replication of the table, the minimum number of instances required
+   * @param instancePartitionsType Type of the instance partitions to read the instances from
+   * @return the instances of partition 0, replica 0 of the instance partitions
+   * @throws IllegalArgumentException if the instance partitions do not contain exactly 1 partition and 1 replica
+   * @throws IllegalStateException if there are less instances than the replication
+   */
+  static List<String> getInstancesForBalanceNumStrategy(HelixManager helixManager, TableConfig tableConfig,
+      int replication, InstancePartitionsType instancePartitionsType) {
+    InstancePartitions instancePartitions =
+        InstancePartitionsUtils.fetchOrComputeInstancePartitions(helixManager, tableConfig, instancePartitionsType);
+    Preconditions.checkArgument(instancePartitions.getNumPartitions() == 1 && instancePartitions.getNumReplicas() == 1,
+        "The instance partitions: %s should contain only 1 partition and 1 replica", instancePartitions.getName());
+    List<String> instances = instancePartitions.getInstances(0, 0);
+    // NOTE: Guava Preconditions only substitutes "%s" placeholders, so "%d" must not be used in the message template
+    Preconditions.checkState(instances.size() >= replication,
+        "There are less instances: %s than the replication: %s for table: %s", instances.size(), replication,
+        tableConfig.getTableName());
+    return instances;
+  }
+
+  /**
+   * Rebalances the table with Helix AutoRebalanceStrategy for the balance number segment assignment strategy.
+   * <p>All the segments in the current assignment are rebalanced over the given instances, with {@code replication}
+   * replicas in ONLINE state requested for each of them. Helix works on a copy of the current assignment, so the
+   * passed in assignment is left untouched.
+   *
+   * @param currentAssignment Current segment assignment (map from segment name to instance state map)
+   * @param instances Instances to rebalance the segments over
+   * @param replication Number of ONLINE replicas for each segment
+   * @return the map fields of the assignment computed by Helix
+   */
+  static Map<String, Map<String, String>> rebalanceTableWithHelixAutoRebalanceStrategy(
+      Map<String, Map<String, String>> currentAssignment, List<String> instances, int replication) {
+    // Hand a sorted deep copy of the current assignment to Helix because it might change the assignment passed in
+    Map<String, Map<String, String>> assignmentCopy = new TreeMap<>();
+    currentAssignment
+        .forEach((segmentName, instanceStateMap) -> assignmentCopy.put(segmentName, new TreeMap<>(instanceStateMap)));
+
+    // Use Helix AutoRebalanceStrategy to rebalance the table, ONLINE is the only state to be assigned
+    LinkedHashMap<String, Integer> stateToCountMap = new LinkedHashMap<>();
+    stateToCountMap.put(SegmentOnlineOfflineStateModel.ONLINE, replication);
+    List<String> segmentNames = new ArrayList<>(currentAssignment.keySet());
+    AutoRebalanceStrategy helixStrategy = new AutoRebalanceStrategy(null, segmentNames, stateToCountMap);
+    return helixStrategy.computePartitionAssignment(instances, instances, assignmentCopy, null).getMapFields();
+  }
+
+  /**
+   * Rebalances the table for the replica-group based segment assignment strategy.
+   * <p>The number of partitions for the segments can be different from the number of partitions in the instance
+   * partitions, so each segment partition id is mapped to an instance partition id by taking it modulo the number of
+   * instance partitions. Each segment partition is then rebalanced on its own.
+   *
+   * @param currentAssignment Current segment assignment (map from segment name to instance state map)
+   * @param instancePartitions Instance partitions to rebalance the segments over
+   * @param partitionIdToSegmentsMap Map from segment partition id to the segments in the partition
+   * @return the new assignment, sorted by segment name
+   */
+  static Map<String, Map<String, String>> rebalanceReplicaGroupBasedTable(
+      Map<String, Map<String, String>> currentAssignment, InstancePartitions instancePartitions,
+      Map<Integer, Set<String>> partitionIdToSegmentsMap) {
+    int numInstancePartitions = instancePartitions.getNumPartitions();
+    Map<String, Map<String, String>> rebalancedAssignment = new TreeMap<>();
+    partitionIdToSegmentsMap.forEach((segmentPartitionId, segmentsInPartition) -> {
+      // Uniformly spray the segment partitions over the instance partitions
+      int instancePartitionId = segmentPartitionId % numInstancePartitions;
+      rebalanceReplicaGroupBasedPartition(currentAssignment, instancePartitions, instancePartitionId,
+          segmentsInPartition, rebalancedAssignment);
+    });
+    return rebalancedAssignment;
+  }
+
+  /**
+   * Rebalances one partition of the table for the replica-group based segment assignment strategy, and puts the
+   * result into the given new assignment.
+   * <ul>
+   *   <li>
+   *     1. Calculate the target number of segments on each server (ceiling of segments per server)
+   *   </li>
+   *   <li>
+   *     2. Loop over the current assignment, and for each segment of the partition keep it on the first of its
+   *     current servers which has not reached the target number of segments; track the segments which cannot be kept
+   *   </li>
+   *   <li>
+   *     3. Assign each of the tracked segments to the server with the least segments, or the smallest index if there
+   *     is a tie
+   *   </li>
+   *   <li>
+   *     4. Mirror the assignment to other replicas (the server at the same index is used in every replica)
+   *   </li>
+   * </ul>
+   * <p>Only segments which show up in the current assignment are assigned.
+   *
+   * @param currentAssignment Current segment assignment (map from segment name to instance state map)
+   * @param instancePartitions Instance partitions to rebalance the segments over
+   * @param partitionId Id of the partition in the instance partitions
+   * @param segments Segments of the partition
+   * @param newAssignment Assignment to put the rebalanced segments into
+   */
+  static void rebalanceReplicaGroupBasedPartition(Map<String, Map<String, String>> currentAssignment,
+      InstancePartitions instancePartitions, int partitionId, Set<String> segments,
+      Map<String, Map<String, String>> newAssignment) {
+    // The assignment is computed on the instances in replica 0
+    List<String> replicaZeroInstances = instancePartitions.getInstances(partitionId, 0);
+    Map<String, Integer> instanceIds = getInstanceNameToIdMap(replicaZeroInstances);
+
+    // NOTE: in order to minimize the segment movements, the target is the ceiling of the quotient
+    int numInstances = replicaZeroInstances.size();
+    int targetPerInstance = (segments.size() + numInstances - 1) / numInstances;
+
+    // Keep a segment where it is as long as one of its current instances is below the target, otherwise mark it as
+    // pending
+    int[] assignedCounts = new int[numInstances];
+    List<String> pendingSegments = new ArrayList<>();
+    for (Map.Entry<String, Map<String, String>> entry : currentAssignment.entrySet()) {
+      String segmentName = entry.getKey();
+      // Only handle segments in the partition
+      if (segments.contains(segmentName)) {
+        int keptInstanceId = -1;
+        for (String currentInstance : entry.getValue().keySet()) {
+          Integer candidateId = instanceIds.get(currentInstance);
+          if (candidateId != null && assignedCounts[candidateId] < targetPerInstance) {
+            keptInstanceId = candidateId;
+            break;
+          }
+        }
+        if (keptInstanceId >= 0) {
+          newAssignment.put(segmentName,
+              getReplicaGroupBasedInstanceStateMap(instancePartitions, partitionId, keptInstanceId));
+          assignedCounts[keptInstanceId]++;
+        } else {
+          pendingSegments.add(segmentName);
+        }
+      }
+    }
+
+    // Assign each pending segment to the instance with the least segments, or the smallest id if there is a tie
+    PriorityQueue<Pairs.IntPair> leastLoadedFirst = new PriorityQueue<>(numInstances, Pairs.intPairComparator());
+    for (int id = 0; id < numInstances; id++) {
+      leastLoadedFirst.add(new Pairs.IntPair(assignedCounts[id], id));
+    }
+    for (String segmentName : pendingSegments) {
+      // Take the instance out of the heap, bump its segment count and put it back so that the heap gets re-ordered
+      Pairs.IntPair countAndInstanceId = leastLoadedFirst.remove();
+      int targetInstanceId = countAndInstanceId.getRight();
+      newAssignment
+          .put(segmentName, getReplicaGroupBasedInstanceStateMap(instancePartitions, partitionId, targetInstanceId));
+      countAndInstanceId.setLeft(countAndInstanceId.getLeft() + 1);
+      leastLoadedFirst.add(countAndInstanceId);
+    }
+  }
+
+  /**
+   * Builds the instance state map (instance name to Helix partition state) of a segment for the replica-group based
+   * segment assignment strategy. For each replica of the given partition, the instance at index {@code instanceId}
+   * is picked from the instance partitions and marked ONLINE. The result can be put into the segment assignment.
+   */
+  private static Map<String, String> getReplicaGroupBasedInstanceStateMap(InstancePartitions instancePartitions,
+      int partitionId, int instanceId) {
+    int replicaCount = instancePartitions.getNumReplicas();
+    Map<String, String> stateByInstance = new TreeMap<>();
+    for (int replica = 0; replica < replicaCount; replica++) {
+      // The same index is used in every replica, which mirrors the assignment across replicas
+      List<String> replicaInstances = instancePartitions.getInstances(partitionId, replica);
+      stateByInstance.put(replicaInstances.get(instanceId), SegmentOnlineOfflineStateModel.ONLINE);
+    }
+    return stateByInstance;
+  }
+
+  /**
+   * Builds an instance state map in which every given instance is mapped to the same Helix partition state. The map
+   * is sorted by instance name and can be put into the segment assignment.
+   */
+  static Map<String, String> getInstanceStateMap(List<String> instances, String state) {
+    Map<String, String> stateByInstance = new TreeMap<>();
+    instances.forEach(instance -> stateByInstance.put(instance, state));
+    return stateByInstance;
+  }
+
+  /**
+   * Returns a map from instance name to number of segments to be moved to it.
+   * <p>A segment has to be moved to an instance when the instance hosts the segment in the new assignment but not in
+   * the old one. A segment that does not exist in the old assignment counts as a move to every instance it is assigned
+   * to in the new assignment. Instances that do not receive any segment are not included in the result.
+   *
+   * @param oldAssignment assignment before the change (segment name to instance state map)
+   * @param newAssignment assignment after the change (segment name to instance state map)
+   * @return map, sorted by instance name, from instance name to the number of segments it has to pick up
+   */
+  static Map<String, Integer> getNumSegmentsToBeMovedPerInstance(Map<String, Map<String, String>> oldAssignment,
+      Map<String, Map<String, String>> newAssignment) {
+    Map<String, Integer> numSegmentsToBeMovedPerInstance = new TreeMap<>();
+    for (Map.Entry<String, Map<String, String>> entry : newAssignment.entrySet()) {
+      // The old instance state map is null for a segment that only exists in the new assignment
+      Map<String, String> oldInstanceStateMap = oldAssignment.get(entry.getKey());
+      // For each new assigned instance, check if the segment needs to be moved to it
+      for (String instanceName : entry.getValue().keySet()) {
+        if (oldInstanceStateMap == null || !oldInstanceStateMap.containsKey(instanceName)) {
+          numSegmentsToBeMovedPerInstance.merge(instanceName, 1, Integer::sum);
+        }
+      }
+    }
+    return numSegmentsToBeMovedPerInstance;
+  }
+
+  /**
+   * Splits a segment assignment into two parts: COMPLETED segments (at least one instance of the segment is in ONLINE
+   * state) and CONSUMING segments (all remaining segments).
+   */
+  static class CompletedConsumingSegmentAssignmentPair {
+    private final Map<String, Map<String, String>> _completedSegmentAssignment = new TreeMap<>();
+    private final Map<String, Map<String, String>> _consumingSegmentAssignment = new TreeMap<>();
+
+    CompletedConsumingSegmentAssignmentPair(Map<String, Map<String, String>> segmentAssignment) {
+      segmentAssignment.forEach((segmentName, instanceStateMap) -> {
+        // A single ONLINE instance is enough to treat the segment as COMPLETED
+        boolean completed = instanceStateMap.containsValue(RealtimeSegmentOnlineOfflineStateModel.ONLINE);
+        Map<String, Map<String, String>> target =
+            completed ? _completedSegmentAssignment : _consumingSegmentAssignment;
+        target.put(segmentName, instanceStateMap);
+      });
+    }
+
+    /** Returns the segments that have at least one ONLINE instance, keyed by segment name. */
+    Map<String, Map<String, String>> getCompletedSegmentAssignment() {
+      return _completedSegmentAssignment;
+    }
+
+    /** Returns the segments that have no ONLINE instance, keyed by segment name. */
+    Map<String, Map<String, String>> getConsumingSegmentAssignment() {
+      return _consumingSegmentAssignment;
+    }
+  }
+}
diff --git a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineBalanceNumSegmentAssignmentStrategyTest.java b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineBalanceNumSegmentAssignmentStrategyTest.java
new file mode 100644
index 0000000..762773b
--- /dev/null
+++ b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineBalanceNumSegmentAssignmentStrategyTest.java
@@ -0,0 +1,137 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+import org.apache.helix.HelixManager;
+import org.apache.helix.ZNRecord;
+import org.apache.helix.store.zk.ZkHelixPropertyStore;
+import org.apache.pinot.common.config.TableConfig;
+import org.apache.pinot.common.metadata.ZKMetadataProvider;
+import org.apache.pinot.common.utils.CommonConstants.Helix.StateModel.SegmentOnlineOfflineStateModel;
+import org.apache.pinot.common.utils.CommonConstants.Helix.TableType;
+import org.apache.pinot.common.utils.InstancePartitionsType;
+import org.apache.pinot.controller.helix.core.assignment.InstancePartitions;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+
+/**
+ * Tests the balance-number segment assignment strategy for OFFLINE tables. All 10 instances are put into a single
+ * entry of the instance partitions (partition 0, replica 0) and the table is configured with 3 replicas.
+ */
+public class OfflineBalanceNumSegmentAssignmentStrategyTest {
+  private static final int NUM_REPLICAS = 3;
+  private static final String SEGMENT_NAME_PREFIX = "segment_";
+  private static final int NUM_SEGMENTS = 100;
+  private static final List<String> SEGMENTS =
+      SegmentAssignmentTestUtils.getNameList(SEGMENT_NAME_PREFIX, NUM_SEGMENTS);
+  private static final String INSTANCE_NAME_PREFIX = "instance_";
+  private static final int NUM_INSTANCES = 10;
+  private static final List<String> INSTANCES =
+      SegmentAssignmentTestUtils.getNameList(INSTANCE_NAME_PREFIX, NUM_INSTANCES);
+  private static final String RAW_TABLE_NAME = "testTable";
+  private static final String INSTANCE_PARTITIONS_NAME =
+      InstancePartitionsType.OFFLINE.getInstancePartitionsName(RAW_TABLE_NAME);
+
+  private SegmentAssignmentStrategy _strategy;
+
+  @BeforeClass
+  public void setUp() {
+    // Instance partitions served by the mocked property store:
+    // {
+    //   0_0=[instance_0, instance_1, instance_2, instance_3, instance_4, instance_5, instance_6, instance_7, instance_8, instance_9]
+    // }
+    InstancePartitions instancePartitions = new InstancePartitions(INSTANCE_PARTITIONS_NAME);
+    instancePartitions.setInstances(0, 0, INSTANCES);
+    ZNRecord instancePartitionsZNRecord = instancePartitions.toZNRecord();
+
+    // Mock the HelixManager so that its property store returns the instance partitions above
+    @SuppressWarnings("unchecked")
+    ZkHelixPropertyStore<ZNRecord> propertyStore = mock(ZkHelixPropertyStore.class);
+    when(propertyStore
+        .get(eq(ZKMetadataProvider.constructPropertyStorePathForInstancePartitions(INSTANCE_PARTITIONS_NAME)), any(),
+            anyInt())).thenReturn(instancePartitionsZNRecord);
+    HelixManager helixManager = mock(HelixManager.class);
+    when(helixManager.getHelixPropertyStore()).thenReturn(propertyStore);
+
+    // No segment assignment strategy is set in the table config, see testFactory() for the strategy that is picked
+    TableConfig.Builder tableConfigBuilder =
+        new TableConfig.Builder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).setNumReplicas(NUM_REPLICAS);
+    _strategy = SegmentAssignmentStrategyFactory.getSegmentAssignmentStrategy(helixManager, tableConfigBuilder.build());
+  }
+
+  @Test
+  public void testFactory() {
+    assertTrue(_strategy instanceof OfflineBalanceNumSegmentAssignmentStrategy);
+  }
+
+  @Test
+  public void testAssignSegment() {
+    Map<String, Map<String, String>> assignment = new TreeMap<>();
+
+    // Instances are expected to be picked in a round-robin fashion across consecutive segments:
+    // Segment 0 should be assigned to instance 0, 1, 2
+    // Segment 1 should be assigned to instance 3, 4, 5
+    // Segment 2 should be assigned to instance 6, 7, 8
+    // Segment 3 should be assigned to instance 9, 0, 1
+    // Segment 4 should be assigned to instance 2, 3, 4
+    // ...
+    int numReplicasChecked = 0;
+    for (String segmentName : SEGMENTS) {
+      List<String> assignedInstances = _strategy.assignSegment(segmentName, assignment);
+      assertEquals(assignedInstances.size(), NUM_REPLICAS);
+      for (String assignedInstance : assignedInstances) {
+        assertEquals(assignedInstance, INSTANCES.get(numReplicasChecked % NUM_INSTANCES));
+        numReplicasChecked++;
+      }
+      // Record the assignment so that the next segment is assigned on top of it
+      assignment.put(segmentName,
+          SegmentAssignmentUtils.getInstanceStateMap(assignedInstances, SegmentOnlineOfflineStateModel.ONLINE));
+    }
+  }
+
+  @Test
+  public void testTableBalanced() {
+    // Assign all segments one by one, each on top of the assignment built so far
+    Map<String, Map<String, String>> assignment = new TreeMap<>();
+    for (String segmentName : SEGMENTS) {
+      Map<String, String> instanceStateMap = SegmentAssignmentUtils
+          .getInstanceStateMap(_strategy.assignSegment(segmentName, assignment), SegmentOnlineOfflineStateModel.ONLINE);
+      assignment.put(segmentName, instanceStateMap);
+    }
+
+    // There should be 100 segments assigned
+    assertEquals(assignment.size(), NUM_SEGMENTS);
+    // Each segment should have 3 replicas
+    for (Map<String, String> instanceStateMap : assignment.values()) {
+      assertEquals(instanceStateMap.size(), NUM_REPLICAS);
+    }
+    // Each instance should have 30 segments assigned
+    int[] expectedNumSegmentsAssignedPerInstance = new int[NUM_INSTANCES];
+    Arrays.fill(expectedNumSegmentsAssignedPerInstance, NUM_SEGMENTS * NUM_REPLICAS / NUM_INSTANCES);
+    int[] numSegmentsAssignedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsAssignedPerInstance(assignment, INSTANCES);
+    assertEquals(numSegmentsAssignedPerInstance, expectedNumSegmentsAssignedPerInstance);
+    // Current assignment should already be balanced
+    assertEquals(_strategy.rebalanceTable(assignment, null), assignment);
+  }
+}
diff --git a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineReplicaGroupSegmentAssignmentStrategyTest.java b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineReplicaGroupSegmentAssignmentStrategyTest.java
new file mode 100644
index 0000000..47ecdd8
--- /dev/null
+++ b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineReplicaGroupSegmentAssignmentStrategyTest.java
@@ -0,0 +1,289 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+import org.apache.helix.HelixManager;
+import org.apache.helix.ZNRecord;
+import org.apache.helix.store.zk.ZkHelixPropertyStore;
+import org.apache.pinot.common.config.ReplicaGroupStrategyConfig;
+import org.apache.pinot.common.config.TableConfig;
+import org.apache.pinot.common.config.TableNameBuilder;
+import org.apache.pinot.common.metadata.ZKMetadataProvider;
+import org.apache.pinot.common.metadata.segment.ColumnPartitionMetadata;
+import org.apache.pinot.common.metadata.segment.OfflineSegmentZKMetadata;
+import org.apache.pinot.common.metadata.segment.SegmentPartitionMetadata;
+import org.apache.pinot.common.utils.CommonConstants;
+import org.apache.pinot.common.utils.CommonConstants.Helix.StateModel.SegmentOnlineOfflineStateModel;
+import org.apache.pinot.common.utils.InstancePartitionsType;
+import org.apache.pinot.controller.helix.core.assignment.InstancePartitions;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+
+/**
+ * Tests the replica-group segment assignment strategy for OFFLINE tables in two set-ups: without partitions (each
+ * replica group holds a single partition with 6 instances) and with segments partitioned on {@code partitionColumn}
+ * (each replica group holds 3 partitions with 2 instances each).
+ */
+public class OfflineReplicaGroupSegmentAssignmentStrategyTest {
+  private static final int NUM_REPLICAS = 3;
+  private static final String SEGMENT_NAME_PREFIX = "segment_";
+  private static final int NUM_SEGMENTS = 90;
+  private static final List<String> SEGMENTS =
+      SegmentAssignmentTestUtils.getNameList(SEGMENT_NAME_PREFIX, NUM_SEGMENTS);
+  private static final String INSTANCE_NAME_PREFIX = "instance_";
+  private static final int NUM_INSTANCES = 18;
+  private static final List<String> INSTANCES =
+      SegmentAssignmentTestUtils.getNameList(INSTANCE_NAME_PREFIX, NUM_INSTANCES);
+  private static final String RAW_TABLE_NAME_WITHOUT_PARTITION = "testTableWithoutPartition";
+  private static final String INSTANCE_PARTITIONS_NAME_WITHOUT_PARTITION =
+      InstancePartitionsType.OFFLINE.getInstancePartitionsName(RAW_TABLE_NAME_WITHOUT_PARTITION);
+  private static final String RAW_TABLE_NAME_WITH_PARTITION = "testTableWithPartition";
+  private static final String OFFLINE_TABLE_NAME_WITH_PARTITION =
+      TableNameBuilder.OFFLINE.tableNameWithType(RAW_TABLE_NAME_WITH_PARTITION);
+  private static final String INSTANCE_PARTITIONS_NAME_WITH_PARTITION =
+      InstancePartitionsType.OFFLINE.getInstancePartitionsName(RAW_TABLE_NAME_WITH_PARTITION);
+  private static final String PARTITION_COLUMN = "partitionColumn";
+  private static final int NUM_PARTITIONS = 3;
+  // Derived sizes, shared by the set-up and by the expectations so that both are always computed the same way
+  private static final int NUM_INSTANCES_PER_REPLICA = NUM_INSTANCES / NUM_REPLICAS;
+  private static final int NUM_INSTANCES_PER_PARTITION = NUM_INSTANCES_PER_REPLICA / NUM_PARTITIONS;
+
+  private SegmentAssignmentStrategy _strategyWithoutPartition;
+  private SegmentAssignmentStrategy _strategyWithPartition;
+
+  @BeforeClass
+  public void setUp() {
+    // {
+    //   0_0=[instance_0, instance_1, instance_2, instance_3, instance_4, instance_5],
+    //   0_1=[instance_6, instance_7, instance_8, instance_9, instance_10, instance_11],
+    //   0_2=[instance_12, instance_13, instance_14, instance_15, instance_16, instance_17]
+    // }
+    InstancePartitions instancePartitionsWithoutPartition =
+        new InstancePartitions(INSTANCE_PARTITIONS_NAME_WITHOUT_PARTITION);
+    int instanceIdToAdd = 0;
+    for (int replicaId = 0; replicaId < NUM_REPLICAS; replicaId++) {
+      List<String> instancesForReplica = new ArrayList<>(NUM_INSTANCES_PER_REPLICA);
+      for (int i = 0; i < NUM_INSTANCES_PER_REPLICA; i++) {
+        instancesForReplica.add(INSTANCES.get(instanceIdToAdd++));
+      }
+      instancePartitionsWithoutPartition.setInstances(0, replicaId, instancesForReplica);
+    }
+
+    // Mock HelixManager: the property store only has to serve the instance partitions
+    @SuppressWarnings("unchecked")
+    ZkHelixPropertyStore<ZNRecord> propertyStoreWithoutPartitions = mock(ZkHelixPropertyStore.class);
+    when(propertyStoreWithoutPartitions.get(eq(ZKMetadataProvider
+        .constructPropertyStorePathForInstancePartitions(INSTANCE_PARTITIONS_NAME_WITHOUT_PARTITION)), any(), anyInt()))
+        .thenReturn(instancePartitionsWithoutPartition.toZNRecord());
+    HelixManager helixManagerWithoutPartitions = mock(HelixManager.class);
+    when(helixManagerWithoutPartitions.getHelixPropertyStore()).thenReturn(propertyStoreWithoutPartitions);
+
+    TableConfig tableConfigWithoutPartitions =
+        new TableConfig.Builder(CommonConstants.Helix.TableType.OFFLINE).setTableName(RAW_TABLE_NAME_WITHOUT_PARTITION)
+            .setNumReplicas(NUM_REPLICAS).setSegmentAssignmentStrategy(
+            SegmentAssignmentStrategyFactory.Strategy.ReplicaGroupSegmentAssignmentStrategy.name()).build();
+    _strategyWithoutPartition = SegmentAssignmentStrategyFactory
+        .getSegmentAssignmentStrategy(helixManagerWithoutPartitions, tableConfigWithoutPartitions);
+
+    // {
+    //   0_0=[instance_0, instance_1],
+    //   0_1=[instance_6, instance_7],
+    //   0_2=[instance_12, instance_13],
+    //   1_0=[instance_2, instance_3],
+    //   1_1=[instance_8, instance_9],
+    //   1_2=[instance_14, instance_15],
+    //   2_0=[instance_4, instance_5],
+    //   2_1=[instance_10, instance_11],
+    //   2_2=[instance_16, instance_17]
+    // }
+    InstancePartitions instancePartitionsWithPartition =
+        new InstancePartitions(INSTANCE_PARTITIONS_NAME_WITH_PARTITION);
+    instanceIdToAdd = 0;
+    for (int replicaId = 0; replicaId < NUM_REPLICAS; replicaId++) {
+      for (int partitionId = 0; partitionId < NUM_PARTITIONS; partitionId++) {
+        // The instances of a replica are split evenly across its partitions
+        List<String> instancesForPartition = new ArrayList<>(NUM_INSTANCES_PER_PARTITION);
+        for (int i = 0; i < NUM_INSTANCES_PER_PARTITION; i++) {
+          instancesForPartition.add(INSTANCES.get(instanceIdToAdd++));
+        }
+        instancePartitionsWithPartition.setInstances(partitionId, replicaId, instancesForPartition);
+      }
+    }
+
+    // Mock HelixManager: the property store serves the instance partitions and the segment ZK metadata, which carries
+    // the partition id of each segment (segment id modulo number of partitions)
+    @SuppressWarnings("unchecked")
+    ZkHelixPropertyStore<ZNRecord> propertyStoreWithPartitions = mock(ZkHelixPropertyStore.class);
+    when(propertyStoreWithPartitions.get(
+        eq(ZKMetadataProvider.constructPropertyStorePathForInstancePartitions(INSTANCE_PARTITIONS_NAME_WITH_PARTITION)),
+        any(), anyInt())).thenReturn(instancePartitionsWithPartition.toZNRecord());
+    List<ZNRecord> segmentZKMetadataZNRecords = new ArrayList<>(NUM_SEGMENTS);
+    for (int segmentId = 0; segmentId < NUM_SEGMENTS; segmentId++) {
+      String segmentName = SEGMENTS.get(segmentId);
+      OfflineSegmentZKMetadata segmentZKMetadata = new OfflineSegmentZKMetadata();
+      segmentZKMetadata.setSegmentName(segmentName);
+      int partitionId = segmentId % NUM_PARTITIONS;
+      segmentZKMetadata.setPartitionMetadata(new SegmentPartitionMetadata(Collections.singletonMap(PARTITION_COLUMN,
+          new ColumnPartitionMetadata(null, NUM_PARTITIONS, Collections.singleton(partitionId)))));
+      ZNRecord segmentZKMetadataZNRecord = segmentZKMetadata.toZNRecord();
+      when(propertyStoreWithPartitions.get(
+          eq(ZKMetadataProvider.constructPropertyStorePathForSegment(OFFLINE_TABLE_NAME_WITH_PARTITION, segmentName)),
+          any(), anyInt())).thenReturn(segmentZKMetadataZNRecord);
+      segmentZKMetadataZNRecords.add(segmentZKMetadataZNRecord);
+    }
+    when(propertyStoreWithPartitions
+        .getChildren(eq(ZKMetadataProvider.constructPropertyStorePathForResource(OFFLINE_TABLE_NAME_WITH_PARTITION)),
+            any(), anyInt())).thenReturn(segmentZKMetadataZNRecords);
+    HelixManager helixManagerWithPartitions = mock(HelixManager.class);
+    when(helixManagerWithPartitions.getHelixPropertyStore()).thenReturn(propertyStoreWithPartitions);
+
+    ReplicaGroupStrategyConfig strategyConfig = new ReplicaGroupStrategyConfig();
+    strategyConfig.setPartitionColumn(PARTITION_COLUMN);
+    TableConfig tableConfigWithPartitions =
+        new TableConfig.Builder(CommonConstants.Helix.TableType.OFFLINE).setTableName(RAW_TABLE_NAME_WITH_PARTITION)
+            .setNumReplicas(NUM_REPLICAS).setSegmentAssignmentStrategy(
+            SegmentAssignmentStrategyFactory.Strategy.ReplicaGroupSegmentAssignmentStrategy.name()).build();
+    tableConfigWithPartitions.getValidationConfig().setReplicaGroupStrategyConfig(strategyConfig);
+    _strategyWithPartition = SegmentAssignmentStrategyFactory
+        .getSegmentAssignmentStrategy(helixManagerWithPartitions, tableConfigWithPartitions);
+  }
+
+  @Test
+  public void testFactory() {
+    assertTrue(_strategyWithoutPartition instanceof OfflineReplicaGroupSegmentAssignmentStrategy);
+    assertTrue(_strategyWithPartition instanceof OfflineReplicaGroupSegmentAssignmentStrategy);
+  }
+
+  @Test
+  public void testAssignSegmentWithoutPartition() {
+    Map<String, Map<String, String>> currentAssignment = new TreeMap<>();
+    for (int segmentId = 0; segmentId < NUM_SEGMENTS; segmentId++) {
+      String segmentName = SEGMENTS.get(segmentId);
+      List<String> instancesAssigned = _strategyWithoutPartition.assignSegment(segmentName, currentAssignment);
+      assertEquals(instancesAssigned.size(), NUM_REPLICAS);
+      for (int replicaId = 0; replicaId < NUM_REPLICAS; replicaId++) {
+
+        // Segment 0 should be assigned to instance 0, 6, 12
+        // Segment 1 should be assigned to instance 1, 7, 13
+        // Segment 2 should be assigned to instance 2, 8, 14
+        // Segment 3 should be assigned to instance 3, 9, 15
+        // Segment 4 should be assigned to instance 4, 10, 16
+        // Segment 5 should be assigned to instance 5, 11, 17
+        // Segment 6 should be assigned to instance 0, 6, 12
+        // Segment 7 should be assigned to instance 1, 7, 13
+        // ...
+        int expectedAssignedInstanceId =
+            segmentId % NUM_INSTANCES_PER_REPLICA + replicaId * NUM_INSTANCES_PER_REPLICA;
+        assertEquals(instancesAssigned.get(replicaId), INSTANCES.get(expectedAssignedInstanceId));
+      }
+      currentAssignment.put(segmentName,
+          SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentOnlineOfflineStateModel.ONLINE));
+    }
+  }
+
+  @Test
+  public void testAssignSegmentWithPartition() {
+    Map<String, Map<String, String>> currentAssignment = new TreeMap<>();
+    for (int segmentId = 0; segmentId < NUM_SEGMENTS; segmentId++) {
+      String segmentName = SEGMENTS.get(segmentId);
+      List<String> instancesAssigned = _strategyWithPartition.assignSegment(segmentName, currentAssignment);
+      assertEquals(instancesAssigned.size(), NUM_REPLICAS);
+      int partitionId = segmentId % NUM_PARTITIONS;
+      for (int replicaId = 0; replicaId < NUM_REPLICAS; replicaId++) {
+
+        // Segment 0 (partition 0) should be assigned to instance 0, 6, 12
+        // Segment 1 (partition 1) should be assigned to instance 2, 8, 14
+        // Segment 2 (partition 2) should be assigned to instance 4, 10, 16
+        // Segment 3 (partition 0) should be assigned to instance 1, 7, 13
+        // Segment 4 (partition 1) should be assigned to instance 3, 9, 15
+        // Segment 5 (partition 2) should be assigned to instance 5, 11, 17
+        // Segment 6 (partition 0) should be assigned to instance 0, 6, 12
+        // Segment 7 (partition 1) should be assigned to instance 2, 8, 14
+        // ...
+        int expectedAssignedInstanceId =
+            (segmentId % NUM_INSTANCES_PER_REPLICA) / NUM_PARTITIONS + partitionId * NUM_INSTANCES_PER_PARTITION
+                + replicaId * NUM_INSTANCES_PER_REPLICA;
+        assertEquals(instancesAssigned.get(replicaId), INSTANCES.get(expectedAssignedInstanceId));
+      }
+      currentAssignment.put(segmentName,
+          SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentOnlineOfflineStateModel.ONLINE));
+    }
+  }
+
+  @Test
+  public void testTableBalancedWithoutPartition() {
+    assertTableBalanced(_strategyWithoutPartition);
+  }
+
+  @Test
+  public void testTableBalancedWithPartition() {
+    assertTableBalanced(_strategyWithPartition);
+  }
+
+  /**
+   * Assigns all segments one by one with the given strategy (each on top of the assignment built so far) and verifies
+   * that the resulting assignment is evenly spread over the instances and is left unchanged by a table rebalance.
+   */
+  private static void assertTableBalanced(SegmentAssignmentStrategy strategy) {
+    Map<String, Map<String, String>> currentAssignment = new TreeMap<>();
+    for (String segmentName : SEGMENTS) {
+      List<String> instancesAssigned = strategy.assignSegment(segmentName, currentAssignment);
+      currentAssignment.put(segmentName,
+          SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentOnlineOfflineStateModel.ONLINE));
+    }
+
+    // There should be 90 segments assigned
+    assertEquals(currentAssignment.size(), NUM_SEGMENTS);
+    // Each segment should have 3 replicas
+    for (Map<String, String> instanceStateMap : currentAssignment.values()) {
+      assertEquals(instanceStateMap.size(), NUM_REPLICAS);
+    }
+    // Each instance should have 15 segments assigned
+    int[] numSegmentsAssignedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsAssignedPerInstance(currentAssignment, INSTANCES);
+    int[] expectedNumSegmentsAssignedPerInstance = new int[NUM_INSTANCES];
+    int numSegmentsPerInstance = NUM_SEGMENTS * NUM_REPLICAS / NUM_INSTANCES;
+    Arrays.fill(expectedNumSegmentsAssignedPerInstance, numSegmentsPerInstance);
+    assertEquals(numSegmentsAssignedPerInstance, expectedNumSegmentsAssignedPerInstance);
+    // Current assignment should already be balanced
+    assertEquals(strategy.rebalanceTable(currentAssignment, null), currentAssignment);
+  }
+}
diff --git a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeBalanceNumSegmentAssignmentStrategyTest.java b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeBalanceNumSegmentAssignmentStrategyTest.java
new file mode 100644
index 0000000..87ef462
--- /dev/null
+++ b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeBalanceNumSegmentAssignmentStrategyTest.java
@@ -0,0 +1,208 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+import org.apache.commons.configuration.BaseConfiguration;
+import org.apache.helix.HelixManager;
+import org.apache.helix.ZNRecord;
+import org.apache.helix.store.zk.ZkHelixPropertyStore;
+import org.apache.pinot.common.config.TableConfig;
+import org.apache.pinot.common.metadata.ZKMetadataProvider;
+import org.apache.pinot.common.utils.CommonConstants;
+import org.apache.pinot.common.utils.CommonConstants.Helix.StateModel.RealtimeSegmentOnlineOfflineStateModel;
+import org.apache.pinot.common.utils.InstancePartitionsType;
+import org.apache.pinot.common.utils.LLCSegmentName;
+import org.apache.pinot.controller.helix.core.assignment.InstancePartitions;
+import org.apache.pinot.controller.helix.core.rebalance.RebalanceUserConfigConstants;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+
+public class RealtimeBalanceNumSegmentAssignmentStrategyTest {
+  private static final int NUM_REPLICAS = 3;
+  private static final int NUM_PARTITIONS = 4;
+  private static final int NUM_SEGMENTS = 100;
+  private static final String CONSUMING_INSTANCE_NAME_PREFIX = "consumingInstance_";
+  private static final int NUM_CONSUMING_INSTANCES = 9;
+  private static final List<String> CONSUMING_INSTANCES =
+      SegmentAssignmentTestUtils.getNameList(CONSUMING_INSTANCE_NAME_PREFIX, NUM_CONSUMING_INSTANCES);
+  private static final String COMPLETED_INSTANCE_NAME_PREFIX = "completedInstance_";
+  private static final int NUM_COMPLETED_INSTANCES = 10;
+  private static final List<String> COMPLETED_INSTANCES =
+      SegmentAssignmentTestUtils.getNameList(COMPLETED_INSTANCE_NAME_PREFIX, NUM_COMPLETED_INSTANCES);
+  private static final String RAW_TABLE_NAME = "testTable";
+  private static final String CONSUMING_INSTANCE_PARTITIONS_NAME =
+      InstancePartitionsType.CONSUMING.getInstancePartitionsName(RAW_TABLE_NAME);
+  private static final String COMPLETED_INSTANCE_PARTITIONS_NAME =
+      InstancePartitionsType.COMPLETED.getInstancePartitionsName(RAW_TABLE_NAME);
+
+  private List<String> _segments;
+  private SegmentAssignmentStrategy _strategy;
+
+  @BeforeClass
+  public void setUp() {
+    _segments = new ArrayList<>(NUM_SEGMENTS);
+    for (int segmentId = 0; segmentId < NUM_SEGMENTS; segmentId++) {
+      _segments.add(new LLCSegmentName(RAW_TABLE_NAME, segmentId % NUM_PARTITIONS, segmentId / NUM_PARTITIONS,
+          System.currentTimeMillis()).getSegmentName());
+    }
+
+    // Consuming instances ("instance_N" is short for "consumingInstance_N"; pXrY marks partition X, replica Y):
+    // {
+    //   0_0=[instance_0, instance_1, instance_2, instance_3, instance_4, instance_5, instance_6, instance_7, instance_8]
+    // }
+    //        p0r0        p0r1        p0r2        p1r0        p1r1        p1r2        p2r0        p2r1        p2r2
+    //        p3r0        p3r1        p3r2
+    InstancePartitions consumingInstancePartitions = new InstancePartitions(CONSUMING_INSTANCE_PARTITIONS_NAME);
+    consumingInstancePartitions.setInstances(0, 0, CONSUMING_INSTANCES);
+
+    // Completed instances ("instance_N" is short for "completedInstance_N"):
+    // {
+    //   0_0=[instance_0, instance_1, instance_2, instance_3, instance_4, instance_5, instance_6, instance_7, instance_8, instance_9]
+    // }
+    InstancePartitions completedInstancePartitions = new InstancePartitions(COMPLETED_INSTANCE_PARTITIONS_NAME);
+    completedInstancePartitions.setInstances(0, 0, COMPLETED_INSTANCES);
+
+    // Mock a HelixManager whose property store serves both instance partitions as ZNRecords
+    @SuppressWarnings("unchecked")
+    ZkHelixPropertyStore<ZNRecord> propertyStore = mock(ZkHelixPropertyStore.class);
+    when(propertyStore
+        .get(eq(ZKMetadataProvider.constructPropertyStorePathForInstancePartitions(CONSUMING_INSTANCE_PARTITIONS_NAME)),
+            any(), anyInt())).thenReturn(consumingInstancePartitions.toZNRecord());
+    when(propertyStore
+        .get(eq(ZKMetadataProvider.constructPropertyStorePathForInstancePartitions(COMPLETED_INSTANCE_PARTITIONS_NAME)),
+            any(), anyInt())).thenReturn(completedInstancePartitions.toZNRecord());
+    HelixManager helixManager = mock(HelixManager.class);
+    when(helixManager.getHelixPropertyStore()).thenReturn(propertyStore);
+
+    TableConfig tableConfig =
+        new TableConfig.Builder(CommonConstants.Helix.TableType.REALTIME).setTableName(RAW_TABLE_NAME)
+            .setNumReplicas(NUM_REPLICAS).setLLC(true).build();
+    _strategy = SegmentAssignmentStrategyFactory.getSegmentAssignmentStrategy(helixManager, tableConfig);
+  }
+
+  @Test
+  public void testFactory() {
+    assertTrue(_strategy instanceof RealtimeBalanceNumSegmentAssignmentStrategy);
+  }
+
+  @Test
+  public void testAssignSegment() {
+    Map<String, Map<String, String>> currentAssignment = new TreeMap<>();
+    for (int segmentId = 0; segmentId < NUM_SEGMENTS; segmentId++) {
+      String segmentName = _segments.get(segmentId);
+      List<String> instancesAssigned = _strategy.assignSegment(segmentName, currentAssignment);
+      assertEquals(instancesAssigned.size(), NUM_REPLICAS);
+      for (int replicaId = 0; replicaId < NUM_REPLICAS; replicaId++) {
+
+        // Segment 0 (partition 0) should be assigned to instance 0, 1, 2
+        // Segment 1 (partition 1) should be assigned to instance 3, 4, 5
+        // Segment 2 (partition 2) should be assigned to instance 6, 7, 8
+        // Segment 3 (partition 3) should be assigned to instance 0, 1, 2
+        // Segment 4 (partition 0) should be assigned to instance 0, 1, 2
+        // Segment 5 (partition 1) should be assigned to instance 3, 4, 5
+        // ... ("instance N" above means CONSUMING_INSTANCES.get(N))
+        int partitionId = segmentId % NUM_PARTITIONS;
+        int expectedAssignedInstanceId = (partitionId * NUM_REPLICAS + replicaId) % NUM_CONSUMING_INSTANCES;
+        assertEquals(instancesAssigned.get(replicaId), CONSUMING_INSTANCES.get(expectedAssignedInstanceId));
+      }
+      addToAssignment(currentAssignment, segmentId, instancesAssigned);
+    }
+  }
+
+  @Test
+  public void testRelocateCompletedSegments() {
+    Map<String, Map<String, String>> currentAssignment = new TreeMap<>();
+    for (int segmentId = 0; segmentId < NUM_SEGMENTS; segmentId++) {
+      String segmentName = _segments.get(segmentId);
+      List<String> instancesAssigned = _strategy.assignSegment(segmentName, currentAssignment);
+      addToAssignment(currentAssignment, segmentId, instancesAssigned);
+    }
+
+    // There should be 100 segments assigned
+    assertEquals(currentAssignment.size(), NUM_SEGMENTS);
+    // Each segment should have 3 replicas
+    for (Map<String, String> instanceStateMap : currentAssignment.values()) {
+      assertEquals(instanceStateMap.size(), NUM_REPLICAS);
+    }
+
+    // Rebalance with an empty config should relocate all completed (ONLINE) segments to the completed instances
+    Map<String, Map<String, String>> newAssignment =
+        _strategy.rebalanceTable(currentAssignment, new BaseConfiguration());
+    assertEquals(newAssignment.size(), NUM_SEGMENTS);
+    for (int segmentId = 0; segmentId < NUM_SEGMENTS; segmentId++) {
+      if (segmentId < NUM_SEGMENTS - NUM_PARTITIONS) {
+        // Completed (ONLINE) segments: all but the last segment of each partition
+        Map<String, String> instanceStateMap = newAssignment.get(_segments.get(segmentId));
+        for (Map.Entry<String, String> entry : instanceStateMap.entrySet()) {
+          assertTrue(entry.getKey().startsWith(COMPLETED_INSTANCE_NAME_PREFIX));
+          assertEquals(entry.getValue(), RealtimeSegmentOnlineOfflineStateModel.ONLINE);
+        }
+      } else {
+        // Consuming (CONSUMING) segments: the last segment of each partition
+        Map<String, String> instanceStateMap = newAssignment.get(_segments.get(segmentId));
+        for (Map.Entry<String, String> entry : instanceStateMap.entrySet()) {
+          assertTrue(entry.getKey().startsWith(CONSUMING_INSTANCE_NAME_PREFIX));
+          assertEquals(entry.getValue(), RealtimeSegmentOnlineOfflineStateModel.CONSUMING);
+        }
+      }
+    }
+    // Relocated segments should be balanced (each instance should have at least 96 * 3 / 10 = 28 segments assigned)
+    int[] numSegmentsAssignedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsAssignedPerInstance(newAssignment, COMPLETED_INSTANCES);
+    assertEquals(numSegmentsAssignedPerInstance.length, NUM_COMPLETED_INSTANCES);
+    int expectedMinNumSegmentsPerInstance = (NUM_SEGMENTS - NUM_PARTITIONS) * NUM_REPLICAS / NUM_COMPLETED_INSTANCES;
+    for (int i = 0; i < NUM_COMPLETED_INSTANCES; i++) {
+      assertTrue(numSegmentsAssignedPerInstance[i] >= expectedMinNumSegmentsPerInstance);
+    }
+
+    // Rebalancing with INCLUDE_CONSUMING enabled (completed and consuming segments) should give the same assignment
+    BaseConfiguration config = new BaseConfiguration();
+    config.setProperty(RebalanceUserConfigConstants.INCLUDE_CONSUMING, true);
+    assertEquals(_strategy.rebalanceTable(currentAssignment, config), newAssignment);
+  }
+
+  private void addToAssignment(Map<String, Map<String, String>> currentAssignment, int segmentId,
+      List<String> instancesAssigned) {
+    // Change the state of the previous segment in the same partition from CONSUMING to ONLINE if it exists
+    if (segmentId >= NUM_PARTITIONS) {
+      String lastSegmentInPartition = _segments.get(segmentId - NUM_PARTITIONS);
+      Map<String, String> instanceStateMap = currentAssignment.get(lastSegmentInPartition);
+      currentAssignment.put(lastSegmentInPartition, SegmentAssignmentUtils
+          .getInstanceStateMap(new ArrayList<>(instanceStateMap.keySet()),
+              RealtimeSegmentOnlineOfflineStateModel.ONLINE));
+    }
+
+    // Add the new segment into the assignment as CONSUMING
+    currentAssignment.put(_segments.get(segmentId), SegmentAssignmentUtils
+        .getInstanceStateMap(instancesAssigned, RealtimeSegmentOnlineOfflineStateModel.CONSUMING));
+  }
+}
diff --git a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeReplicaGroupSegmentAssignmentStrategyTest.java b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeReplicaGroupSegmentAssignmentStrategyTest.java
new file mode 100644
index 0000000..c4c9a82
--- /dev/null
+++ b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeReplicaGroupSegmentAssignmentStrategyTest.java
@@ -0,0 +1,230 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+import org.apache.commons.configuration.BaseConfiguration;
+import org.apache.helix.HelixManager;
+import org.apache.helix.ZNRecord;
+import org.apache.helix.store.zk.ZkHelixPropertyStore;
+import org.apache.pinot.common.config.TableConfig;
+import org.apache.pinot.common.metadata.ZKMetadataProvider;
+import org.apache.pinot.common.utils.CommonConstants;
+import org.apache.pinot.common.utils.CommonConstants.Helix.StateModel.RealtimeSegmentOnlineOfflineStateModel;
+import org.apache.pinot.common.utils.InstancePartitionsType;
+import org.apache.pinot.common.utils.LLCSegmentName;
+import org.apache.pinot.controller.helix.core.assignment.InstancePartitions;
+import org.apache.pinot.controller.helix.core.rebalance.RebalanceUserConfigConstants;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+
+public class RealtimeReplicaGroupSegmentAssignmentStrategyTest {
+  private static final int NUM_REPLICAS = 3;
+  private static final int NUM_PARTITIONS = 4;
+  private static final int NUM_SEGMENTS = 100;
+  private static final String CONSUMING_INSTANCE_NAME_PREFIX = "consumingInstance_";
+  private static final int NUM_CONSUMING_INSTANCES = 9;
+  private static final List<String> CONSUMING_INSTANCES =
+      SegmentAssignmentTestUtils.getNameList(CONSUMING_INSTANCE_NAME_PREFIX, NUM_CONSUMING_INSTANCES);
+  private static final String COMPLETED_INSTANCE_NAME_PREFIX = "completedInstance_";
+  private static final int NUM_COMPLETED_INSTANCES = 12;
+  private static final List<String> COMPLETED_INSTANCES =
+      SegmentAssignmentTestUtils.getNameList(COMPLETED_INSTANCE_NAME_PREFIX, NUM_COMPLETED_INSTANCES);
+  private static final String RAW_TABLE_NAME = "testTable";
+  private static final String CONSUMING_INSTANCE_PARTITIONS_NAME =
+      InstancePartitionsType.CONSUMING.getInstancePartitionsName(RAW_TABLE_NAME);
+  private static final String COMPLETED_INSTANCE_PARTITIONS_NAME =
+      InstancePartitionsType.COMPLETED.getInstancePartitionsName(RAW_TABLE_NAME);
+
+  private List<String> _segments;
+  private SegmentAssignmentStrategy _strategy;
+
+  @BeforeClass
+  public void setUp() {
+    _segments = new ArrayList<>(NUM_SEGMENTS);
+    for (int segmentId = 0; segmentId < NUM_SEGMENTS; segmentId++) {
+      _segments.add(new LLCSegmentName(RAW_TABLE_NAME, segmentId % NUM_PARTITIONS, segmentId / NUM_PARTITIONS,
+          System.currentTimeMillis()).getSegmentName());
+    }
+
+    // Consuming instances ("instance_N" is short for "consumingInstance_N"; pX marks the column hosting partition X):
+    // {
+    //   0_0=[instance_0, instance_1, instance_2],
+    //   0_1=[instance_3, instance_4, instance_5],
+    //   0_2=[instance_6, instance_7, instance_8]
+    // }
+    //        p0          p1          p2
+    //        p3
+    InstancePartitions consumingInstancePartitions = new InstancePartitions(CONSUMING_INSTANCE_PARTITIONS_NAME);
+    int numConsumingInstancesPerReplica = NUM_CONSUMING_INSTANCES / NUM_REPLICAS;
+    int consumingInstanceIdToAdd = 0;
+    for (int replicaId = 0; replicaId < NUM_REPLICAS; replicaId++) {
+      List<String> consumingInstancesForReplica = new ArrayList<>(numConsumingInstancesPerReplica);
+      for (int i = 0; i < numConsumingInstancesPerReplica; i++) {
+        consumingInstancesForReplica.add(CONSUMING_INSTANCES.get(consumingInstanceIdToAdd++));
+      }
+      consumingInstancePartitions.setInstances(0, replicaId, consumingInstancesForReplica);
+    }
+
+    // Completed instances ("instance_N" is short for "completedInstance_N"):
+    // {
+    //   0_0=[instance_0, instance_1, instance_2, instance_3],
+    //   0_1=[instance_4, instance_5, instance_6, instance_7],
+    //   0_2=[instance_8, instance_9, instance_10, instance_11]
+    // }
+    InstancePartitions completedInstancePartitions = new InstancePartitions(COMPLETED_INSTANCE_PARTITIONS_NAME);
+    int numCompletedInstancesPerReplica = NUM_COMPLETED_INSTANCES / NUM_REPLICAS;
+    int completedInstanceIdToAdd = 0;
+    for (int replicaId = 0; replicaId < NUM_REPLICAS; replicaId++) {
+      List<String> completedInstancesForReplica = new ArrayList<>(numCompletedInstancesPerReplica);
+      for (int i = 0; i < numCompletedInstancesPerReplica; i++) {
+        completedInstancesForReplica.add(COMPLETED_INSTANCES.get(completedInstanceIdToAdd++));
+      }
+      completedInstancePartitions.setInstances(0, replicaId, completedInstancesForReplica);
+    }
+
+    // Mock a HelixManager whose property store serves both instance partitions as ZNRecords
+    @SuppressWarnings("unchecked")
+    ZkHelixPropertyStore<ZNRecord> propertyStore = mock(ZkHelixPropertyStore.class);
+    when(propertyStore
+        .get(eq(ZKMetadataProvider.constructPropertyStorePathForInstancePartitions(CONSUMING_INSTANCE_PARTITIONS_NAME)),
+            any(), anyInt())).thenReturn(consumingInstancePartitions.toZNRecord());
+    when(propertyStore
+        .get(eq(ZKMetadataProvider.constructPropertyStorePathForInstancePartitions(COMPLETED_INSTANCE_PARTITIONS_NAME)),
+            any(), anyInt())).thenReturn(completedInstancePartitions.toZNRecord());
+    HelixManager helixManager = mock(HelixManager.class);
+    when(helixManager.getHelixPropertyStore()).thenReturn(propertyStore);
+
+    TableConfig tableConfig =
+        new TableConfig.Builder(CommonConstants.Helix.TableType.REALTIME).setTableName(RAW_TABLE_NAME)
+            .setNumReplicas(NUM_REPLICAS).setLLC(true).setSegmentAssignmentStrategy(
+            SegmentAssignmentStrategyFactory.Strategy.ReplicaGroupSegmentAssignmentStrategy.name()).build();
+    _strategy = SegmentAssignmentStrategyFactory.getSegmentAssignmentStrategy(helixManager, tableConfig);
+  }
+
+  @Test
+  public void testFactory() {
+    assertTrue(_strategy instanceof RealtimeReplicaGroupSegmentAssignmentStrategy);
+  }
+
+  @Test
+  public void testAssignSegment() {
+    int numInstancesPerReplica = NUM_CONSUMING_INSTANCES / NUM_REPLICAS;
+    Map<String, Map<String, String>> currentAssignment = new TreeMap<>();
+    for (int segmentId = 0; segmentId < NUM_SEGMENTS; segmentId++) {
+      String segmentName = _segments.get(segmentId);
+      List<String> instancesAssigned = _strategy.assignSegment(segmentName, currentAssignment);
+      assertEquals(instancesAssigned.size(), NUM_REPLICAS);
+      for (int replicaId = 0; replicaId < NUM_REPLICAS; replicaId++) {
+
+        // Segment 0 (partition 0) should be assigned to instance 0, 3, 6
+        // Segment 1 (partition 1) should be assigned to instance 1, 4, 7
+        // Segment 2 (partition 2) should be assigned to instance 2, 5, 8
+        // Segment 3 (partition 3) should be assigned to instance 0, 3, 6
+        // Segment 4 (partition 0) should be assigned to instance 0, 3, 6
+        // Segment 5 (partition 1) should be assigned to instance 1, 4, 7
+        // ... ("instance N" above means CONSUMING_INSTANCES.get(N))
+        int partitionId = segmentId % NUM_PARTITIONS;
+        int expectedAssignedInstanceId = partitionId % numInstancesPerReplica + replicaId * numInstancesPerReplica;
+        assertEquals(instancesAssigned.get(replicaId), CONSUMING_INSTANCES.get(expectedAssignedInstanceId));
+      }
+      addToAssignment(currentAssignment, segmentId, instancesAssigned);
+    }
+  }
+
+  @Test
+  public void testRelocateCompletedSegments() {
+    Map<String, Map<String, String>> currentAssignment = new TreeMap<>();
+    for (int segmentId = 0; segmentId < NUM_SEGMENTS; segmentId++) {
+      String segmentName = _segments.get(segmentId);
+      List<String> instancesAssigned = _strategy.assignSegment(segmentName, currentAssignment);
+      addToAssignment(currentAssignment, segmentId, instancesAssigned);
+    }
+
+    // There should be 100 segments assigned
+    assertEquals(currentAssignment.size(), NUM_SEGMENTS);
+    // Each segment should have 3 replicas
+    for (Map<String, String> instanceStateMap : currentAssignment.values()) {
+      assertEquals(instanceStateMap.size(), NUM_REPLICAS);
+    }
+
+    // Rebalance with an empty config should relocate all completed (ONLINE) segments to the completed instances
+    Map<String, Map<String, String>> newAssignment =
+        _strategy.rebalanceTable(currentAssignment, new BaseConfiguration());
+    assertEquals(newAssignment.size(), NUM_SEGMENTS);
+    for (int segmentId = 0; segmentId < NUM_SEGMENTS; segmentId++) {
+      if (segmentId < NUM_SEGMENTS - NUM_PARTITIONS) {
+        // Completed (ONLINE) segments: all but the last segment of each partition
+        Map<String, String> instanceStateMap = newAssignment.get(_segments.get(segmentId));
+        for (Map.Entry<String, String> entry : instanceStateMap.entrySet()) {
+          assertTrue(entry.getKey().startsWith(COMPLETED_INSTANCE_NAME_PREFIX));
+          assertEquals(entry.getValue(), RealtimeSegmentOnlineOfflineStateModel.ONLINE);
+        }
+      } else {
+        // Consuming (CONSUMING) segments: the last segment of each partition
+        Map<String, String> instanceStateMap = newAssignment.get(_segments.get(segmentId));
+        for (Map.Entry<String, String> entry : instanceStateMap.entrySet()) {
+          assertTrue(entry.getKey().startsWith(CONSUMING_INSTANCE_NAME_PREFIX));
+          assertEquals(entry.getValue(), RealtimeSegmentOnlineOfflineStateModel.CONSUMING);
+        }
+      }
+    }
+    // Relocated segments should be balanced (each instance should have 96 * 3 / 12 = 24 segments assigned)
+    int[] numSegmentsAssignedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsAssignedPerInstance(newAssignment, COMPLETED_INSTANCES);
+    int[] expectedNumSegmentsAssignedPerInstance = new int[NUM_COMPLETED_INSTANCES];
+    int numSegmentsPerInstance = (NUM_SEGMENTS - NUM_PARTITIONS) * NUM_REPLICAS / NUM_COMPLETED_INSTANCES;
+    Arrays.fill(expectedNumSegmentsAssignedPerInstance, numSegmentsPerInstance);
+    assertEquals(numSegmentsAssignedPerInstance, expectedNumSegmentsAssignedPerInstance);
+
+    // Rebalancing with INCLUDE_CONSUMING enabled (completed and consuming segments) should give the same assignment
+    BaseConfiguration config = new BaseConfiguration();
+    config.setProperty(RebalanceUserConfigConstants.INCLUDE_CONSUMING, true);
+    assertEquals(_strategy.rebalanceTable(currentAssignment, config), newAssignment);
+  }
+
+  private void addToAssignment(Map<String, Map<String, String>> currentAssignment, int segmentId,
+      List<String> instancesAssigned) {
+    // Change the state of the previous segment in the same partition from CONSUMING to ONLINE if it exists
+    if (segmentId >= NUM_PARTITIONS) {
+      String lastSegmentInPartition = _segments.get(segmentId - NUM_PARTITIONS);
+      Map<String, String> instanceStateMap = currentAssignment.get(lastSegmentInPartition);
+      currentAssignment.put(lastSegmentInPartition, SegmentAssignmentUtils
+          .getInstanceStateMap(new ArrayList<>(instanceStateMap.keySet()),
+              RealtimeSegmentOnlineOfflineStateModel.ONLINE));
+    }
+
+    // Add the new segment into the assignment as CONSUMING
+    currentAssignment.put(_segments.get(segmentId), SegmentAssignmentUtils
+        .getInstanceStateMap(instancesAssigned, RealtimeSegmentOnlineOfflineStateModel.CONSUMING));
+  }
+}
diff --git a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentTestUtils.java b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentTestUtils.java
new file mode 100644
index 0000000..c9a7ae0
--- /dev/null
+++ b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentTestUtils.java
@@ -0,0 +1,39 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment;
+
+import java.util.ArrayList;
+import java.util.List;
+
+
+public class SegmentAssignmentTestUtils {
+  private SegmentAssignmentTestUtils() {
+  }
+
+  /**
+   * Generates {@code numNames} names by appending the indexes 0 to {@code numNames - 1} to {@code namePrefix}.
+   */
+  public static List<String> getNameList(String namePrefix, int numNames) {
+    List<String> nameList = new ArrayList<>(numNames);
+    while (nameList.size() < numNames) {
+      nameList.add(namePrefix + nameList.size());
+    }
+    return nameList;
+  }
+}
diff --git a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentUtilsTest.java b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentUtilsTest.java
new file mode 100644
index 0000000..86cfac6
--- /dev/null
+++ b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentUtilsTest.java
@@ -0,0 +1,434 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+import org.apache.pinot.common.utils.CommonConstants.Helix.StateModel.SegmentOnlineOfflineStateModel;
+import org.apache.pinot.controller.helix.core.assignment.InstancePartitions;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+
+public class SegmentAssignmentUtilsTest {
+  private static final int NUM_REPLICAS = 3;
+  private static final String SEGMENT_NAME_PREFIX = "segment_";
+  private static final String INSTANCE_NAME_PREFIX = "instance_";
+
+  /**
+   * Exercises {@code SegmentAssignmentUtils.rebalanceTableWithHelixAutoRebalanceStrategy} starting from 100 segments
+   * with 3 replicas each, sprayed round-robin over 10 instances. Checks the new assignment (number of segments,
+   * replicas per segment, segments per instance, segments to be moved per instance) after replacing 1 instance,
+   * removing 5 instances, adding 5 instances, and changing all instances.
+   */
+  @Test
+  public void testRebalanceTableWithHelixAutoRebalanceStrategy() {
+    int numSegments = 100;
+    List<String> segments = SegmentAssignmentTestUtils.getNameList(SEGMENT_NAME_PREFIX, numSegments);
+    int numInstances = 10;
+    List<String> instances = SegmentAssignmentTestUtils.getNameList(INSTANCE_NAME_PREFIX, numInstances);
+
+    // Uniformly spray segments to the instances:
+    // [instance_0,   instance_1,   instance_2,   instance_3,   instance_4,   instance_5,   instance_6,   instance_7,   instance_8,   instance_9]
+    //  segment_0(r0) segment_0(r1) segment_0(r2) segment_1(r0) segment_1(r1) segment_1(r2) segment_2(r0) segment_2(r1) segment_2(r2) segment_3(r0)
+    //  segment_3(r1) segment_3(r2) segment_4(r0) segment_4(r1) segment_4(r2) segment_5(r0) segment_5(r1) segment_5(r2) segment_6(r0) segment_6(r1)
+    //  ...
+    Map<String, Map<String, String>> currentAssignment = new TreeMap<>();
+    int assignedInstanceId = 0;
+    for (String segmentName : segments) {
+      List<String> instancesAssigned = new ArrayList<>(NUM_REPLICAS);
+      for (int replicaId = 0; replicaId < NUM_REPLICAS; replicaId++) {
+        instancesAssigned.add(instances.get(assignedInstanceId));
+        assignedInstanceId = (assignedInstanceId + 1) % numInstances;
+      }
+      currentAssignment.put(segmentName,
+          SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentOnlineOfflineStateModel.ONLINE));
+    }
+
+    // There should be 100 segments assigned
+    assertEquals(currentAssignment.size(), numSegments);
+    // Each segment should have 3 replicas
+    for (Map<String, String> instanceStateMap : currentAssignment.values()) {
+      assertEquals(instanceStateMap.size(), NUM_REPLICAS);
+    }
+    // Each instance should have 30 segments assigned
+    int[] numSegmentsAssignedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsAssignedPerInstance(currentAssignment, instances);
+    int[] expectedNumSegmentsAssignedPerInstance = new int[numInstances];
+    int numSegmentsPerInstance = numSegments * NUM_REPLICAS / numInstances;
+    Arrays.fill(expectedNumSegmentsAssignedPerInstance, numSegmentsPerInstance);
+    assertEquals(numSegmentsAssignedPerInstance, expectedNumSegmentsAssignedPerInstance);
+    // Current assignment should already be balanced
+    assertEquals(
+        SegmentAssignmentUtils.rebalanceTableWithHelixAutoRebalanceStrategy(currentAssignment, instances, NUM_REPLICAS),
+        currentAssignment);
+
+    // NOTE: each scenario below rebalances from the original currentAssignment (not from the previous scenario's
+    // result), so the scenarios are independent of each other
+    // Replace instance_0 with instance_10
+    // {0_0=[instance_10, instance_1, instance_2, instance_3, instance_4, instance_5, instance_6, instance_7, instance_8, instance_9]}
+    List<String> newInstances = new ArrayList<>(instances);
+    String newInstanceName = INSTANCE_NAME_PREFIX + 10;
+    newInstances.set(0, newInstanceName);
+    Map<String, Map<String, String>> newAssignment = SegmentAssignmentUtils
+        .rebalanceTableWithHelixAutoRebalanceStrategy(currentAssignment, newInstances, NUM_REPLICAS);
+    // There should be 100 segments assigned
+    assertEquals(newAssignment.size(), numSegments);
+    // Each segment should have 3 replicas
+    for (Map<String, String> instanceStateMap : newAssignment.values()) {
+      assertEquals(instanceStateMap.size(), NUM_REPLICAS);
+    }
+    // Each instance should have 30 segments assigned
+    numSegmentsAssignedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsAssignedPerInstance(newAssignment, newInstances);
+    assertEquals(numSegmentsAssignedPerInstance, expectedNumSegmentsAssignedPerInstance);
+    // All segments on instance_0 should be moved to instance_10
+    Map<String, Integer> numSegmentsToBeMovedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsToBeMovedPerInstance(currentAssignment, newAssignment);
+    assertEquals(numSegmentsToBeMovedPerInstance.size(), 1);
+    assertEquals((int) numSegmentsToBeMovedPerInstance.get(newInstanceName), numSegmentsPerInstance);
+    String oldInstanceName = INSTANCE_NAME_PREFIX + 0;
+    for (String segmentName : segments) {
+      if (currentAssignment.get(segmentName).containsKey(oldInstanceName)) {
+        assertTrue(newAssignment.get(segmentName).containsKey(newInstanceName));
+      }
+    }
+
+    // Remove 5 instances
+    // {0_0=[instance_0, instance_1, instance_2, instance_3, instance_4]}
+    int newNumInstances = numInstances - 5;
+    newInstances = SegmentAssignmentTestUtils.getNameList(INSTANCE_NAME_PREFIX, newNumInstances);
+    newAssignment = SegmentAssignmentUtils
+        .rebalanceTableWithHelixAutoRebalanceStrategy(currentAssignment, newInstances, NUM_REPLICAS);
+    // There should be 100 segments assigned
+    assertEquals(newAssignment.size(), numSegments);
+    // Each segment should have 3 replicas
+    for (Map<String, String> instanceStateMap : newAssignment.values()) {
+      assertEquals(instanceStateMap.size(), NUM_REPLICAS);
+    }
+    // The segments are not perfectly balanced, but should be deterministic
+    numSegmentsAssignedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsAssignedPerInstance(newAssignment, newInstances);
+    assertEquals(numSegmentsAssignedPerInstance[0], 56);
+    assertEquals(numSegmentsAssignedPerInstance[1], 60);
+    assertEquals(numSegmentsAssignedPerInstance[2], 60);
+    assertEquals(numSegmentsAssignedPerInstance[3], 60);
+    assertEquals(numSegmentsAssignedPerInstance[4], 64);
+    numSegmentsToBeMovedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsToBeMovedPerInstance(currentAssignment, newAssignment);
+    assertEquals(numSegmentsToBeMovedPerInstance.size(), newNumInstances);
+    assertEquals((int) numSegmentsToBeMovedPerInstance.get(newInstances.get(0)), 26);
+    assertEquals((int) numSegmentsToBeMovedPerInstance.get(newInstances.get(1)), 30);
+    assertEquals((int) numSegmentsToBeMovedPerInstance.get(newInstances.get(2)), 30);
+    assertEquals((int) numSegmentsToBeMovedPerInstance.get(newInstances.get(3)), 30);
+    assertEquals((int) numSegmentsToBeMovedPerInstance.get(newInstances.get(4)), 34);
+
+    // Add 5 instances
+    // {0_0=[instance_0, instance_1, instance_2, instance_3, instance_4, instance_5, instance_6, instance_7, instance_8, instance_9, instance_10, instance_11, instance_12, instance_13, instance_14]}
+    newNumInstances = numInstances + 5;
+    newInstances = SegmentAssignmentTestUtils.getNameList(INSTANCE_NAME_PREFIX, newNumInstances);
+    newAssignment = SegmentAssignmentUtils
+        .rebalanceTableWithHelixAutoRebalanceStrategy(currentAssignment, newInstances, NUM_REPLICAS);
+    // There should be 100 segments assigned
+    assertEquals(newAssignment.size(), numSegments);
+    // Each segment should have 3 replicas
+    for (Map<String, String> instanceStateMap : newAssignment.values()) {
+      assertEquals(instanceStateMap.size(), NUM_REPLICAS);
+    }
+    // Each instance should have 20 segments assigned
+    numSegmentsAssignedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsAssignedPerInstance(newAssignment, newInstances);
+    expectedNumSegmentsAssignedPerInstance = new int[newNumInstances];
+    int newNumSegmentsPerInstance = numSegments * NUM_REPLICAS / newNumInstances;
+    Arrays.fill(expectedNumSegmentsAssignedPerInstance, newNumSegmentsPerInstance);
+    assertEquals(numSegmentsAssignedPerInstance, expectedNumSegmentsAssignedPerInstance);
+    // Each new added instance should have 20 segments to be moved to it
+    numSegmentsToBeMovedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsToBeMovedPerInstance(currentAssignment, newAssignment);
+    assertEquals(numSegmentsToBeMovedPerInstance.size(), 5);
+    for (int instanceId = numInstances; instanceId < newNumInstances; instanceId++) {
+      assertEquals((int) numSegmentsToBeMovedPerInstance.get(newInstances.get(instanceId)), newNumSegmentsPerInstance);
+    }
+
+    // Change all instances
+    // {0_0=[i_0, i_1, i_2, i_3, i_4, i_5, i_6, i_7, i_8, i_9]}
+    String newInstanceNamePrefix = "i_";
+    newInstances = SegmentAssignmentTestUtils.getNameList(newInstanceNamePrefix, numInstances);
+    newAssignment = SegmentAssignmentUtils
+        .rebalanceTableWithHelixAutoRebalanceStrategy(currentAssignment, newInstances, NUM_REPLICAS);
+    // There should be 100 segments assigned
+    assertEquals(newAssignment.size(), numSegments);
+    // Each segment should have 3 replicas
+    for (Map<String, String> instanceStateMap : newAssignment.values()) {
+      assertEquals(instanceStateMap.size(), NUM_REPLICAS);
+    }
+    // Each instance should have 30 segments assigned
+    numSegmentsAssignedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsAssignedPerInstance(newAssignment, newInstances);
+    expectedNumSegmentsAssignedPerInstance = new int[numInstances];
+    Arrays.fill(expectedNumSegmentsAssignedPerInstance, numSegmentsPerInstance);
+    assertEquals(numSegmentsAssignedPerInstance, expectedNumSegmentsAssignedPerInstance);
+    // Each instance should have 30 segments to be moved to it
+    numSegmentsToBeMovedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsToBeMovedPerInstance(currentAssignment, newAssignment);
+    assertEquals(numSegmentsToBeMovedPerInstance.size(), numInstances);
+    for (String instanceName : newInstances) {
+      assertEquals((int) numSegmentsToBeMovedPerInstance.get(instanceName), numSegmentsPerInstance);
+    }
+  }
+
+  @Test
+  public void testRebalanceReplicaGroupBasedTable() {
+    // Table is rebalanced on a per partition basis, so testing rebalancing one partition is enough
+    // Each scenario below starts from the same currentAssignment (90 segments, 3 replica groups of 3 instances)
+    int numSegments = 90;
+    List<String> segments = SegmentAssignmentTestUtils.getNameList(SEGMENT_NAME_PREFIX, numSegments);
+    Map<Integer, Set<String>> partitionIdToSegmentsMap = Collections.singletonMap(0, new HashSet<>(segments));
+    int numInstances = 9;
+    List<String> instances = SegmentAssignmentTestUtils.getNameList(INSTANCE_NAME_PREFIX, numInstances);
+
+    // {
+    //   0_0=[instance_0, instance_1, instance_2],
+    //   0_1=[instance_3, instance_4, instance_5],
+    //   0_2=[instance_6, instance_7, instance_8]
+    // }
+    InstancePartitions instancePartitions = new InstancePartitions(null);
+    int numInstancesPerReplica = numInstances / NUM_REPLICAS;
+    int instanceIdToAdd = 0;
+    for (int replicaId = 0; replicaId < NUM_REPLICAS; replicaId++) {
+      List<String> instancesForReplica = new ArrayList<>(numInstancesPerReplica);
+      for (int i = 0; i < numInstancesPerReplica; i++) {
+        instancesForReplica.add(instances.get(instanceIdToAdd++));
+      }
+      instancePartitions.setInstances(0, replicaId, instancesForReplica);
+    }
+
+    // Uniformly spray segments to the instances (segment i goes to the (i % 3)-th instance of every replica group):
+    // Replica group 0: [instance_0, instance_1, instance_2],
+    // Replica group 1: [instance_3, instance_4, instance_5],
+    // Replica group 2: [instance_6, instance_7, instance_8]
+    //                   segment_0   segment_1   segment_2
+    //                   segment_3   segment_4   segment_5
+    //                   ...
+    Map<String, Map<String, String>> currentAssignment = new TreeMap<>();
+    for (int segmentId = 0; segmentId < numSegments; segmentId++) {
+      List<String> instancesAssigned = new ArrayList<>(NUM_REPLICAS);
+      for (int replicaId = 0; replicaId < NUM_REPLICAS; replicaId++) {
+        int assignedInstanceId = segmentId % numInstancesPerReplica + replicaId * numInstancesPerReplica;
+        instancesAssigned.add(instances.get(assignedInstanceId));
+      }
+      currentAssignment.put(segments.get(segmentId),
+          SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentOnlineOfflineStateModel.ONLINE));
+    }
+
+    // There should be 90 segments assigned
+    assertEquals(currentAssignment.size(), numSegments);
+    // Each segment should have 3 replicas
+    for (Map<String, String> instanceStateMap : currentAssignment.values()) {
+      assertEquals(instanceStateMap.size(), NUM_REPLICAS);
+    }
+    // Each instance should have 30 segments assigned
+    int[] numSegmentsAssignedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsAssignedPerInstance(currentAssignment, instances);
+    int[] expectedNumSegmentsAssignedPerInstance = new int[numInstances];
+    int numSegmentsPerInstance = numSegments * NUM_REPLICAS / numInstances;
+    Arrays.fill(expectedNumSegmentsAssignedPerInstance, numSegmentsPerInstance);
+    assertEquals(numSegmentsAssignedPerInstance, expectedNumSegmentsAssignedPerInstance);
+    // Current assignment should already be balanced
+    assertEquals(SegmentAssignmentUtils
+            .rebalanceReplicaGroupBasedTable(currentAssignment, instancePartitions, partitionIdToSegmentsMap),
+        currentAssignment);
+
+    // Replace instance_0 with instance_9, instance_4 with instance_10
+    // {
+    //   0_0=[instance_9, instance_1, instance_2],
+    //   0_1=[instance_3, instance_10, instance_5],
+    //   0_2=[instance_6, instance_7, instance_8]
+    // }
+    List<String> newInstances = new ArrayList<>(numInstances);
+    List<String> newReplica0Instances = new ArrayList<>(instancePartitions.getInstances(0, 0));
+    String newReplica0Instance = INSTANCE_NAME_PREFIX + 9;
+    newReplica0Instances.set(0, newReplica0Instance);
+    newInstances.addAll(newReplica0Instances);
+    List<String> newReplica1Instances = new ArrayList<>(instancePartitions.getInstances(0, 1));
+    String newReplica1Instance = INSTANCE_NAME_PREFIX + 10;
+    newReplica1Instances.set(1, newReplica1Instance);
+    newInstances.addAll(newReplica1Instances);
+    List<String> newReplica2Instances = instancePartitions.getInstances(0, 2);
+    newInstances.addAll(newReplica2Instances);
+    InstancePartitions newInstancePartitions = new InstancePartitions(null);
+    newInstancePartitions.setInstances(0, 0, newReplica0Instances);
+    newInstancePartitions.setInstances(0, 1, newReplica1Instances);
+    newInstancePartitions.setInstances(0, 2, newReplica2Instances);
+    Map<String, Map<String, String>> newAssignment = SegmentAssignmentUtils
+        .rebalanceReplicaGroupBasedTable(currentAssignment, newInstancePartitions, partitionIdToSegmentsMap);
+    // There should be 90 segments assigned
+    assertEquals(newAssignment.size(), numSegments);
+    // Each segment should have 3 replicas
+    for (Map<String, String> instanceStateMap : newAssignment.values()) {
+      assertEquals(instanceStateMap.size(), NUM_REPLICAS);
+    }
+    // Each instance should have 30 segments assigned
+    numSegmentsAssignedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsAssignedPerInstance(newAssignment, newInstances);
+    assertEquals(numSegmentsAssignedPerInstance, expectedNumSegmentsAssignedPerInstance);
+    // All segments on instance_0 should be moved to instance_9, all segments on instance_4 should be moved to
+    // instance_10
+    Map<String, Integer> numSegmentsToBeMovedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsToBeMovedPerInstance(currentAssignment, newAssignment);
+    assertEquals(numSegmentsToBeMovedPerInstance.size(), 2);
+    assertEquals((int) numSegmentsToBeMovedPerInstance.get(newReplica0Instance), numSegmentsPerInstance);
+    assertEquals((int) numSegmentsToBeMovedPerInstance.get(newReplica1Instance), numSegmentsPerInstance);
+    String replica0OldInstanceName = INSTANCE_NAME_PREFIX + 0;
+    String replica1OldInstanceName = INSTANCE_NAME_PREFIX + 4;
+    for (String segmentName : segments) {
+      Map<String, String> oldInstanceStateMap = currentAssignment.get(segmentName);
+      if (oldInstanceStateMap.containsKey(replica0OldInstanceName)) {
+        assertTrue(newAssignment.get(segmentName).containsKey(newReplica0Instance));
+      }
+      if (oldInstanceStateMap.containsKey(replica1OldInstanceName)) {
+        assertTrue(newAssignment.get(segmentName).containsKey(newReplica1Instance));
+      }
+    }
+
+    // Remove 3 instances (1 from each replica); newInstancePartitions is reused and re-set via setInstances from here on
+    // {
+    //   0_0=[instance_0, instance_1],
+    //   0_1=[instance_3, instance_4],
+    //   0_2=[instance_6, instance_7]
+    // }
+    int newNumInstances = numInstances - 3;
+    int newNumInstancesPerReplica = newNumInstances / NUM_REPLICAS;
+    newInstances = new ArrayList<>(newNumInstances);
+    for (int replicaId = 0; replicaId < NUM_REPLICAS; replicaId++) {
+      List<String> newInstancesForReplica =
+          instancePartitions.getInstances(0, replicaId).subList(0, newNumInstancesPerReplica);
+      newInstancePartitions.setInstances(0, replicaId, newInstancesForReplica);
+      newInstances.addAll(newInstancesForReplica);
+    }
+    newAssignment = SegmentAssignmentUtils
+        .rebalanceReplicaGroupBasedTable(currentAssignment, newInstancePartitions, partitionIdToSegmentsMap);
+    // There should be 90 segments assigned
+    assertEquals(newAssignment.size(), numSegments);
+    // Each segment should have 3 replicas
+    for (Map<String, String> instanceStateMap : newAssignment.values()) {
+      assertEquals(instanceStateMap.size(), NUM_REPLICAS);
+    }
+    // Each instance should have 45 segments assigned
+    numSegmentsAssignedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsAssignedPerInstance(newAssignment, newInstances);
+    expectedNumSegmentsAssignedPerInstance = new int[newNumInstances];
+    int newNumSegmentsPerInstance = numSegments * NUM_REPLICAS / newNumInstances;
+    Arrays.fill(expectedNumSegmentsAssignedPerInstance, newNumSegmentsPerInstance);
+    assertEquals(numSegmentsAssignedPerInstance, expectedNumSegmentsAssignedPerInstance);
+    // Each instance should have 15 segments to be moved to it (45 expected minus the 30 it was already assigned)
+    numSegmentsToBeMovedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsToBeMovedPerInstance(currentAssignment, newAssignment);
+    assertEquals(numSegmentsToBeMovedPerInstance.size(), newNumInstances);
+    for (String instanceName : newInstances) {
+      assertEquals((int) numSegmentsToBeMovedPerInstance.get(instanceName),
+          newNumSegmentsPerInstance - numSegmentsPerInstance);
+    }
+
+    // Add 6 instances (2 to each replica)
+    // {
+    //   0_0=[instance_0, instance_1, instance_2, instance_9, instance_10],
+    //   0_1=[instance_3, instance_4, instance_5, instance_11, instance_12],
+    //   0_2=[instance_6, instance_7, instance_8, instance_13, instance_14]
+    // }
+    newNumInstances = numInstances + 6;
+    newNumInstancesPerReplica = newNumInstances / NUM_REPLICAS;
+    newInstances = SegmentAssignmentTestUtils.getNameList(INSTANCE_NAME_PREFIX, newNumInstances);
+    instanceIdToAdd = numInstances;
+    for (int replicaId = 0; replicaId < NUM_REPLICAS; replicaId++) {
+      List<String> newInstancesForReplica = new ArrayList<>(instancePartitions.getInstances(0, replicaId));
+      for (int i = 0; i < newNumInstancesPerReplica - numInstancesPerReplica; i++) {
+        newInstancesForReplica.add(newInstances.get(instanceIdToAdd++));
+      }
+      newInstancePartitions.setInstances(0, replicaId, newInstancesForReplica);
+    }
+    newAssignment = SegmentAssignmentUtils
+        .rebalanceReplicaGroupBasedTable(currentAssignment, newInstancePartitions, partitionIdToSegmentsMap);
+    // There should be 90 segments assigned
+    assertEquals(newAssignment.size(), numSegments);
+    // Each segment should have 3 replicas
+    for (Map<String, String> instanceStateMap : newAssignment.values()) {
+      assertEquals(instanceStateMap.size(), NUM_REPLICAS);
+    }
+    // Each instance should have 18 segments assigned
+    numSegmentsAssignedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsAssignedPerInstance(newAssignment, newInstances);
+    expectedNumSegmentsAssignedPerInstance = new int[newNumInstances];
+    newNumSegmentsPerInstance = numSegments * NUM_REPLICAS / newNumInstances;
+    Arrays.fill(expectedNumSegmentsAssignedPerInstance, newNumSegmentsPerInstance);
+    assertEquals(numSegmentsAssignedPerInstance, expectedNumSegmentsAssignedPerInstance);
+    // Each new added instance should have 18 segments to be moved to it
+    numSegmentsToBeMovedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsToBeMovedPerInstance(currentAssignment, newAssignment);
+    assertEquals(numSegmentsToBeMovedPerInstance.size(), 6);
+    for (int instanceId = numInstances; instanceId < newNumInstances; instanceId++) {
+      assertEquals((int) numSegmentsToBeMovedPerInstance.get(newInstances.get(instanceId)), newNumSegmentsPerInstance);
+    }
+
+    // Change all instances
+    // {
+    //   0_0=[i_0, i_1, i_2],
+    //   0_1=[i_3, i_4, i_5],
+    //   0_2=[i_6, i_7, i_8]
+    // }
+    newInstances = SegmentAssignmentTestUtils.getNameList("i_", numInstances);
+    instanceIdToAdd = 0;
+    for (int replicaId = 0; replicaId < NUM_REPLICAS; replicaId++) {
+      List<String> instancesForReplica = new ArrayList<>(numInstancesPerReplica);
+      for (int i = 0; i < numInstancesPerReplica; i++) {
+        instancesForReplica.add(newInstances.get(instanceIdToAdd++));
+      }
+      newInstancePartitions.setInstances(0, replicaId, instancesForReplica);
+    }
+    newAssignment = SegmentAssignmentUtils
+        .rebalanceReplicaGroupBasedTable(currentAssignment, newInstancePartitions, partitionIdToSegmentsMap);
+    // There should be 90 segments assigned
+    assertEquals(newAssignment.size(), numSegments);
+    // Each segment should have 3 replicas
+    for (Map<String, String> instanceStateMap : newAssignment.values()) {
+      assertEquals(instanceStateMap.size(), NUM_REPLICAS);
+    }
+    // Each instance should have 30 segments assigned
+    numSegmentsAssignedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsAssignedPerInstance(newAssignment, newInstances);
+    expectedNumSegmentsAssignedPerInstance = new int[numInstances];
+    Arrays.fill(expectedNumSegmentsAssignedPerInstance, numSegmentsPerInstance);
+    assertEquals(numSegmentsAssignedPerInstance, expectedNumSegmentsAssignedPerInstance);
+    // Each instance should have 30 segments to be moved to it
+    numSegmentsToBeMovedPerInstance =
+        SegmentAssignmentUtils.getNumSegmentsToBeMovedPerInstance(currentAssignment, newAssignment);
+    assertEquals(numSegmentsToBeMovedPerInstance.size(), numInstances);
+    for (String instanceName : newInstances) {
+      assertEquals((int) numSegmentsToBeMovedPerInstance.get(instanceName), numSegmentsPerInstance);
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org