You are viewing a plain-text version of this content. The link to the canonical version was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@hbase.apache.org by ap...@apache.org on 2018/07/05 23:38:19 UTC

[2/2] hbase git commit: HBASE-20840 Backport HBASE-20791 'RSGroupBasedLoadBalancer#setClusterMetrics should pass ClusterMetrics to its internalBalancer' to branch-1 (chenxu)

HBASE-20840 Backport HBASE-20791 'RSGroupBasedLoadBalancer#setClusterMetrics should pass ClusterMetrics to its internalBalancer' to branch-1 (chenxu)


Project: http://git-wip-us.apache.org/repos/asf/hbase/repo
Commit: http://git-wip-us.apache.org/repos/asf/hbase/commit/71a0a3cf
Tree: http://git-wip-us.apache.org/repos/asf/hbase/tree/71a0a3cf
Diff: http://git-wip-us.apache.org/repos/asf/hbase/diff/71a0a3cf

Branch: refs/heads/branch-1.4
Commit: 71a0a3cf0d8925ce81fdde688f8daa4f25e2d0a4
Parents: c5be81f
Author: Andrew Purtell <ap...@apache.org>
Authored: Thu Jul 5 16:11:43 2018 -0700
Committer: Andrew Purtell <ap...@apache.org>
Committed: Thu Jul 5 16:21:43 2018 -0700

----------------------------------------------------------------------
 .../hbase/rsgroup/RSGroupBasedLoadBalancer.java |  12 +-
 .../balancer/RSGroupableBalancerTestBase.java   | 433 ++++++++++++++++++
 .../balancer/TestRSGroupBasedLoadBalancer.java  | 446 +------------------
 ...cerWithStochasticLoadBalancerAsInternal.java | 161 +++++++
 4 files changed, 625 insertions(+), 427 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/hbase/blob/71a0a3cf/hbase-rsgroup/src/main/java/org/apache/hadoop/hbase/rsgroup/RSGroupBasedLoadBalancer.java
----------------------------------------------------------------------
diff --git a/hbase-rsgroup/src/main/java/org/apache/hadoop/hbase/rsgroup/RSGroupBasedLoadBalancer.java b/hbase-rsgroup/src/main/java/org/apache/hadoop/hbase/rsgroup/RSGroupBasedLoadBalancer.java
index 77adf0d..091b02c 100644
--- a/hbase-rsgroup/src/main/java/org/apache/hadoop/hbase/rsgroup/RSGroupBasedLoadBalancer.java
+++ b/hbase-rsgroup/src/main/java/org/apache/hadoop/hbase/rsgroup/RSGroupBasedLoadBalancer.java
@@ -101,11 +101,17 @@ public class RSGroupBasedLoadBalancer implements RSGroupableBalancer, LoadBalanc
   @Override
   public void setConf(Configuration conf) {
     this.config = conf;
+    if (internalBalancer != null) {
+      internalBalancer.setConf(conf);
+    }
   }
 
   @Override
   public void setClusterStatus(ClusterStatus st) {
     this.clusterStatus = st;
+    if (internalBalancer != null) {
+      internalBalancer.setClusterStatus(st);
+    }
   }
 
   @Override
@@ -365,7 +371,7 @@ public class RSGroupBasedLoadBalancer implements RSGroupableBalancer, LoadBalanc
   }
 
   private Map<ServerName, List<HRegionInfo>> correctAssignments(
-       Map<ServerName, List<HRegionInfo>> existingAssignments){
+      Map<ServerName, List<HRegionInfo>> existingAssignments) {
     Map<ServerName, List<HRegionInfo>> correctAssignments =
         new TreeMap<ServerName, List<HRegionInfo>>();
     correctAssignments.put(LoadBalancer.BOGUS_SERVER_NAME, new LinkedList<HRegionInfo>());
@@ -424,7 +430,9 @@ public class RSGroupBasedLoadBalancer implements RSGroupableBalancer, LoadBalanc
         HBASE_GROUP_LOADBALANCER_CLASS,
         StochasticLoadBalancer.class, LoadBalancer.class);
     internalBalancer = ReflectionUtils.newInstance(balancerKlass, config);
-    internalBalancer.setClusterStatus(clusterStatus);
+    if (clusterStatus != null) {
+      internalBalancer.setClusterStatus(clusterStatus);
+    }
     internalBalancer.setMasterServices(masterServices);
     internalBalancer.setConf(config);
     internalBalancer.initialize();

http://git-wip-us.apache.org/repos/asf/hbase/blob/71a0a3cf/hbase-rsgroup/src/test/java/org/apache/hadoop/hbase/master/balancer/RSGroupableBalancerTestBase.java
----------------------------------------------------------------------
diff --git a/hbase-rsgroup/src/test/java/org/apache/hadoop/hbase/master/balancer/RSGroupableBalancerTestBase.java b/hbase-rsgroup/src/test/java/org/apache/hadoop/hbase/master/balancer/RSGroupableBalancerTestBase.java
new file mode 100644
index 0000000..a32b77e
--- /dev/null
+++ b/hbase-rsgroup/src/test/java/org/apache/hadoop/hbase/master/balancer/RSGroupableBalancerTestBase.java
@@ -0,0 +1,433 @@
+/**
+ * Copyright The Apache Software Foundation
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.master.balancer;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import com.google.common.collect.ArrayListMultimap;
+import com.google.common.collect.Lists;
+
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.security.SecureRandom;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.TreeSet;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.hadoop.hbase.HRegionInfo;
+import org.apache.hadoop.hbase.HTableDescriptor;
+import org.apache.hadoop.hbase.ServerName;
+import org.apache.hadoop.hbase.TableDescriptors;
+import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.master.AssignmentManager;
+import org.apache.hadoop.hbase.master.HMaster;
+import org.apache.hadoop.hbase.master.MasterServices;
+import org.apache.hadoop.hbase.master.RegionPlan;
+import org.apache.hadoop.hbase.net.Address;
+import org.apache.hadoop.hbase.rsgroup.RSGroupInfo;
+import org.apache.hadoop.hbase.rsgroup.RSGroupInfoManager;
+import org.apache.hadoop.hbase.util.Bytes;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+public class RSGroupableBalancerTestBase {
+
+  static SecureRandom rand = new SecureRandom();
+  static String[] groups = new String[] { RSGroupInfo.DEFAULT_GROUP, "dg2", "dg3", "dg4" };
+  static TableName table0 = TableName.valueOf("dt0");
+  static TableName[] tables =
+      new TableName[] { TableName.valueOf("dt1"),
+          TableName.valueOf("dt2"),
+          TableName.valueOf("dt3"),
+          TableName.valueOf("dt4")};
+  static List<ServerName> servers;
+  static Map<String, RSGroupInfo> groupMap;
+  static Map<TableName, String> tableMap = new HashMap<>();
+  static List<HTableDescriptor> tableDescs;
+  int[] regionAssignment = new int[] { 2, 5, 7, 10, 4, 3, 1 };
+  static int regionId = 0;
+
+  /**
+   * Invariant is that all servers of a group have load between floor(avg) and
+   * ceiling(avg) number of regions.
+   */
+  protected void assertClusterAsBalanced(
+      ArrayListMultimap<String, ServerAndLoad> groupLoadMap) {
+    for (String gName : groupLoadMap.keySet()) {
+      List<ServerAndLoad> groupLoad = groupLoadMap.get(gName);
+      int numServers = groupLoad.size();
+      int numRegions = 0;
+      int maxRegions = 0;
+      int minRegions = Integer.MAX_VALUE;
+      for (ServerAndLoad server : groupLoad) {
+        int nr = server.getLoad();
+        if (nr > maxRegions) {
+          maxRegions = nr;
+        }
+        if (nr < minRegions) {
+          minRegions = nr;
+        }
+        numRegions += nr;
+      }
+      if (maxRegions - minRegions < 2) {
+        // less than 2 between max and min, can't balance
+        return;
+      }
+      int min = numRegions / numServers;
+      int max = numRegions % numServers == 0 ? min : min + 1;
+
+      for (ServerAndLoad server : groupLoad) {
+        assertTrue(server.getLoad() <= max);
+        assertTrue(server.getLoad() >= min);
+      }
+    }
+  }
+
+  /**
+   * Asserts a valid retained assignment plan.
+   * <p>
+   * Must meet the following conditions:
+   * <ul>
+   * <li>Every input region has an assignment, and to an online server
+   * <li>If a region had an existing assignment to a server with the same
+   * address a a currently online server, it will be assigned to it
+   * </ul>
+   */
+  protected void assertRetainedAssignment(
+      Map<HRegionInfo, ServerName> existing, List<ServerName> servers,
+      Map<ServerName, List<HRegionInfo>> assignment)
+      throws FileNotFoundException, IOException {
+    // Verify condition 1, every region assigned, and to online server
+    Set<ServerName> onlineServerSet = new TreeSet<ServerName>(servers);
+    Set<HRegionInfo> assignedRegions = new TreeSet<HRegionInfo>();
+    for (Map.Entry<ServerName, List<HRegionInfo>> a : assignment.entrySet()) {
+      assertTrue(
+          "Region assigned to server that was not listed as online",
+          onlineServerSet.contains(a.getKey()));
+      for (HRegionInfo r : a.getValue()) {
+        assignedRegions.add(r);
+      }
+    }
+    assertEquals(existing.size(), assignedRegions.size());
+
+    // Verify condition 2, every region must be assigned to correct server.
+    Set<String> onlineHostNames = new TreeSet<String>();
+    for (ServerName s : servers) {
+      onlineHostNames.add(s.getHostname());
+    }
+
+    for (Map.Entry<ServerName, List<HRegionInfo>> a : assignment.entrySet()) {
+      ServerName currentServer = a.getKey();
+      for (HRegionInfo r : a.getValue()) {
+        ServerName oldAssignedServer = existing.get(r);
+        TableName tableName = r.getTable();
+        String groupName =
+            getMockedGroupInfoManager().getRSGroupOfTable(tableName);
+        assertTrue(StringUtils.isNotEmpty(groupName));
+        RSGroupInfo gInfo = getMockedGroupInfoManager().getRSGroup(
+            groupName);
+        assertTrue(
+            "Region is not correctly assigned to group servers.",
+            gInfo.containsServer(currentServer.getAddress()));
+        if (oldAssignedServer != null
+            && onlineHostNames.contains(oldAssignedServer
+            .getHostname())) {
+          // this region was previously assigned somewhere, and that
+          // host is still around, then the host must have been is a
+          // different group.
+          if (!oldAssignedServer.getAddress().equals(currentServer.getAddress())) {
+            assertFalse(gInfo.containsServer(oldAssignedServer.getAddress()));
+          }
+        }
+      }
+    }
+  }
+
+  protected String printStats(
+      ArrayListMultimap<String, ServerAndLoad> groupBasedLoad) {
+    StringBuffer sb = new StringBuffer();
+    sb.append("\n");
+    for (String groupName : groupBasedLoad.keySet()) {
+      sb.append("Stats for group: " + groupName);
+      sb.append("\n");
+      sb.append(groupMap.get(groupName).getServers());
+      sb.append("\n");
+      List<ServerAndLoad> groupLoad = groupBasedLoad.get(groupName);
+      int numServers = groupLoad.size();
+      int totalRegions = 0;
+      sb.append("Per Server Load: \n");
+      for (ServerAndLoad sLoad : groupLoad) {
+        sb.append("Server :" + sLoad.getServerName() + " Load : "
+            + sLoad.getLoad() + "\n");
+        totalRegions += sLoad.getLoad();
+      }
+      sb.append(" Group Statistics : \n");
+      float average = (float) totalRegions / numServers;
+      int max = (int) Math.ceil(average);
+      int min = (int) Math.floor(average);
+      sb.append("[srvr=" + numServers + " rgns=" + totalRegions + " avg="
+          + average + " max=" + max + " min=" + min + "]");
+      sb.append("\n");
+      sb.append("===============================");
+      sb.append("\n");
+    }
+    return sb.toString();
+  }
+
+  protected ArrayListMultimap<String, ServerAndLoad> convertToGroupBasedMap(
+      final Map<ServerName, List<HRegionInfo>> serversMap) throws IOException {
+    ArrayListMultimap<String, ServerAndLoad> loadMap = ArrayListMultimap
+        .create();
+    for (RSGroupInfo gInfo : getMockedGroupInfoManager().listRSGroups()) {
+      Set<Address> groupServers = gInfo.getServers();
+      for (Address server : groupServers) {
+        ServerName actual = null;
+        for(ServerName entry: servers) {
+          if(entry.getAddress().equals(server)) {
+            actual = entry;
+            break;
+          }
+        }
+        List<HRegionInfo> regions = serversMap.get(actual);
+        assertTrue("No load for " + actual, regions != null);
+        loadMap.put(gInfo.getName(),
+            new ServerAndLoad(actual, regions.size()));
+      }
+    }
+    return loadMap;
+  }
+
+  protected ArrayListMultimap<String, ServerAndLoad> reconcile(
+      ArrayListMultimap<String, ServerAndLoad> previousLoad,
+      List<RegionPlan> plans) {
+    ArrayListMultimap<String, ServerAndLoad> result = ArrayListMultimap
+        .create();
+    result.putAll(previousLoad);
+    if (plans != null) {
+      for (RegionPlan plan : plans) {
+        ServerName source = plan.getSource();
+        updateLoad(result, source, -1);
+        ServerName destination = plan.getDestination();
+        updateLoad(result, destination, +1);
+      }
+    }
+    return result;
+  }
+
+  protected void updateLoad(
+      ArrayListMultimap<String, ServerAndLoad> previousLoad,
+      final ServerName sn, final int diff) {
+    for (String groupName : previousLoad.keySet()) {
+      ServerAndLoad newSAL = null;
+      ServerAndLoad oldSAL = null;
+      for (ServerAndLoad sal : previousLoad.get(groupName)) {
+        if (ServerName.isSameAddress(sn, sal.getServerName())) {
+          oldSAL = sal;
+          newSAL = new ServerAndLoad(sn, sal.getLoad() + diff);
+          break;
+        }
+      }
+      if (newSAL != null) {
+        previousLoad.remove(groupName, oldSAL);
+        previousLoad.put(groupName, newSAL);
+        break;
+      }
+    }
+  }
+
+  protected Map<ServerName, List<HRegionInfo>> mockClusterServers() throws IOException {
+    assertTrue(servers.size() == regionAssignment.length);
+    Map<ServerName, List<HRegionInfo>> assignment = new TreeMap<ServerName, List<HRegionInfo>>();
+    for (int i = 0; i < servers.size(); i++) {
+      int numRegions = regionAssignment[i];
+      List<HRegionInfo> regions = assignedRegions(numRegions, servers.get(i));
+      assignment.put(servers.get(i), regions);
+    }
+    return assignment;
+  }
+
+  /**
+   * Generate a list of regions evenly distributed between the tables.
+   *
+   * @param numRegions The number of regions to be generated.
+   * @return List of HRegionInfo.
+   */
+  protected List<HRegionInfo> randomRegions(int numRegions) {
+    List<HRegionInfo> regions = new ArrayList<HRegionInfo>(numRegions);
+    byte[] start = new byte[16];
+    byte[] end = new byte[16];
+    rand.nextBytes(start);
+    rand.nextBytes(end);
+    int regionIdx = rand.nextInt(tables.length);
+    for (int i = 0; i < numRegions; i++) {
+      Bytes.putInt(start, 0, numRegions << 1);
+      Bytes.putInt(end, 0, (numRegions << 1) + 1);
+      int tableIndex = (i + regionIdx) % tables.length;
+      HRegionInfo hri = new HRegionInfo(
+          tables[tableIndex], start, end, false, regionId++);
+      regions.add(hri);
+    }
+    return regions;
+  }
+
+  /**
+   * Generate assigned regions to a given server using group information.
+   *
+   * @param numRegions the num regions to generate
+   * @param sn the servername
+   * @return the list of regions
+   * @throws java.io.IOException Signals that an I/O exception has occurred.
+   */
+  protected List<HRegionInfo> assignedRegions(int numRegions, ServerName sn) throws IOException {
+    List<HRegionInfo> regions = new ArrayList<HRegionInfo>(numRegions);
+    byte[] start = new byte[16];
+    byte[] end = new byte[16];
+    Bytes.putInt(start, 0, numRegions << 1);
+    Bytes.putInt(end, 0, (numRegions << 1) + 1);
+    for (int i = 0; i < numRegions; i++) {
+      TableName tableName = getTableName(sn);
+      HRegionInfo hri = new HRegionInfo(
+          tableName, start, end, false,
+          regionId++);
+      regions.add(hri);
+    }
+    return regions;
+  }
+
+  protected static List<ServerName> generateServers(int numServers) {
+    List<ServerName> servers = new ArrayList<ServerName>(numServers);
+    for (int i = 0; i < numServers; i++) {
+      String host = "server" + rand.nextInt(100000);
+      int port = rand.nextInt(60000);
+      servers.add(ServerName.valueOf(host, port, -1));
+    }
+    return servers;
+  }
+
+  /**
+   * Construct group info, with each group having at least one server.
+   *
+   * @param servers the servers
+   * @param groups the groups
+   * @return the map
+   */
+  protected static Map<String, RSGroupInfo> constructGroupInfo(
+      List<ServerName> servers, String[] groups) {
+    assertTrue(servers != null);
+    assertTrue(servers.size() >= groups.length);
+    int index = 0;
+    Map<String, RSGroupInfo> groupMap = new HashMap<String, RSGroupInfo>();
+    for (String grpName : groups) {
+      RSGroupInfo RSGroupInfo = new RSGroupInfo(grpName);
+      RSGroupInfo.addServer(servers.get(index).getAddress());
+      groupMap.put(grpName, RSGroupInfo);
+      index++;
+    }
+    while (index < servers.size()) {
+      int grpIndex = rand.nextInt(groups.length);
+      groupMap.get(groups[grpIndex]).addServer(servers.get(index).getAddress());
+      index++;
+    }
+    return groupMap;
+  }
+
+  /**
+   * Construct table descriptors evenly distributed between the groups.
+   *
+   * @return the list
+   */
+  protected static List<HTableDescriptor> constructTableDesc(boolean hasBogusTable) {
+    List<HTableDescriptor> tds = Lists.newArrayList();
+    int index = rand.nextInt(groups.length);
+    for (int i = 0; i < tables.length; i++) {
+      HTableDescriptor htd = new HTableDescriptor(tables[i]);
+      int grpIndex = (i + index) % groups.length ;
+      String groupName = groups[grpIndex];
+      tableMap.put(tables[i], groupName);
+      tds.add(htd);
+    }
+    if (hasBogusTable) {
+      tableMap.put(table0, "");
+      tds.add(new HTableDescriptor(table0));
+    }
+    return tds;
+  }
+
+  protected static MasterServices getMockedMaster() throws IOException {
+    TableDescriptors tds = Mockito.mock(TableDescriptors.class);
+    Mockito.when(tds.get(tables[0])).thenReturn(tableDescs.get(0));
+    Mockito.when(tds.get(tables[1])).thenReturn(tableDescs.get(1));
+    Mockito.when(tds.get(tables[2])).thenReturn(tableDescs.get(2));
+    Mockito.when(tds.get(tables[3])).thenReturn(tableDescs.get(3));
+    MasterServices services = Mockito.mock(HMaster.class);
+    Mockito.when(services.getTableDescriptors()).thenReturn(tds);
+    AssignmentManager am = Mockito.mock(AssignmentManager.class);
+    Mockito.when(services.getAssignmentManager()).thenReturn(am);
+    return services;
+  }
+
+  protected static RSGroupInfoManager getMockedGroupInfoManager() throws IOException {
+    RSGroupInfoManager gm = Mockito.mock(RSGroupInfoManager.class);
+    Mockito.when(gm.getRSGroup(Mockito.anyString())).thenAnswer(new Answer<RSGroupInfo>() {
+      @Override
+      public RSGroupInfo answer(InvocationOnMock invocation) throws Throwable {
+        return groupMap.get(invocation.getArguments()[0]);
+      }
+    });
+    Mockito.when(gm.listRSGroups()).thenReturn(
+        Lists.newLinkedList(groupMap.values()));
+    Mockito.when(gm.isOnline()).thenReturn(true);
+    Mockito.when(gm.getRSGroupOfTable(Mockito.any(TableName.class)))
+        .thenAnswer(new Answer<String>() {
+          @Override
+          public String answer(InvocationOnMock invocation) throws Throwable {
+            return tableMap.get(invocation.getArguments()[0]);
+          }
+        });
+    return gm;
+  }
+
+  protected TableName getTableName(ServerName sn) throws IOException {
+    TableName tableName = null;
+    RSGroupInfoManager gm = getMockedGroupInfoManager();
+    RSGroupInfo groupOfServer = null;
+    for(RSGroupInfo gInfo : gm.listRSGroups()){
+      if(gInfo.containsServer(sn.getAddress())){
+        groupOfServer = gInfo;
+        break;
+      }
+    }
+
+    for(HTableDescriptor desc : tableDescs){
+      if(gm.getRSGroupOfTable(desc.getTableName()).endsWith(groupOfServer.getName())){
+        tableName = desc.getTableName();
+      }
+    }
+    return tableName;
+  }
+}

http://git-wip-us.apache.org/repos/asf/hbase/blob/71a0a3cf/hbase-rsgroup/src/test/java/org/apache/hadoop/hbase/master/balancer/TestRSGroupBasedLoadBalancer.java
----------------------------------------------------------------------
diff --git a/hbase-rsgroup/src/test/java/org/apache/hadoop/hbase/master/balancer/TestRSGroupBasedLoadBalancer.java b/hbase-rsgroup/src/test/java/org/apache/hadoop/hbase/master/balancer/TestRSGroupBasedLoadBalancer.java
index e511d14..6170cc1 100644
--- a/hbase-rsgroup/src/test/java/org/apache/hadoop/hbase/master/balancer/TestRSGroupBasedLoadBalancer.java
+++ b/hbase-rsgroup/src/test/java/org/apache/hadoop/hbase/master/balancer/TestRSGroupBasedLoadBalancer.java
@@ -19,8 +19,20 @@
  */
 package org.apache.hadoop.hbase.master.balancer;
 
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
 import com.google.common.collect.ArrayListMultimap;
-import com.google.common.collect.Lists;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
 
 import org.apache.commons.lang.StringUtils;
 import org.apache.commons.logging.Log;
@@ -28,76 +40,31 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.hbase.HBaseConfiguration;
 import org.apache.hadoop.hbase.HRegionInfo;
-import org.apache.hadoop.hbase.HTableDescriptor;
 import org.apache.hadoop.hbase.ServerName;
-import org.apache.hadoop.hbase.TableDescriptors;
 import org.apache.hadoop.hbase.TableName;
-import org.apache.hadoop.hbase.rsgroup.RSGroupBasedLoadBalancer;
-import org.apache.hadoop.hbase.rsgroup.RSGroupInfo;
-import org.apache.hadoop.hbase.rsgroup.RSGroupInfoManager;
-import org.apache.hadoop.hbase.master.AssignmentManager;
-import org.apache.hadoop.hbase.master.HMaster;
 import org.apache.hadoop.hbase.master.LoadBalancer;
-import org.apache.hadoop.hbase.master.MasterServices;
 import org.apache.hadoop.hbase.master.RegionPlan;
 import org.apache.hadoop.hbase.net.Address;
+import org.apache.hadoop.hbase.rsgroup.RSGroupBasedLoadBalancer;
+import org.apache.hadoop.hbase.rsgroup.RSGroupInfo;
 import org.apache.hadoop.hbase.testclassification.SmallTests;
-import org.apache.hadoop.hbase.util.Bytes;
 import org.junit.BeforeClass;
 import org.junit.Test;
 import org.junit.experimental.categories.Category;
-import org.mockito.Mockito;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
-
-import java.io.FileNotFoundException;
-import java.io.IOException;
-import java.security.SecureRandom;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.TreeMap;
-import java.util.TreeSet;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
 
-//TODO use stochastic based load balancer instead
+/**
+ * Test RSGroupBasedLoadBalancer with SimpleLoadBalancer as internal balancer
+ */
 @Category(SmallTests.class)
-public class TestRSGroupBasedLoadBalancer {
-
+public class TestRSGroupBasedLoadBalancer extends RSGroupableBalancerTestBase {
   private static final Log LOG = LogFactory.getLog(TestRSGroupBasedLoadBalancer.class);
   private static RSGroupBasedLoadBalancer loadBalancer;
-  private static SecureRandom rand;
-
-  static String[]  groups = new String[] { RSGroupInfo.DEFAULT_GROUP, "dg2", "dg3",
-      "dg4" };
-  static TableName table0 = TableName.valueOf("dt0");
-  static TableName[] tables =
-      new TableName[] { TableName.valueOf("dt1"),
-          TableName.valueOf("dt2"),
-          TableName.valueOf("dt3"),
-          TableName.valueOf("dt4")};
-  static List<ServerName> servers;
-  static Map<String, RSGroupInfo> groupMap;
-  static Map<TableName, String> tableMap;
-  static List<HTableDescriptor> tableDescs;
-  int[] regionAssignment = new int[] { 2, 5, 7, 10, 4, 3, 1 };
-  static int regionId = 0;
 
   @BeforeClass
   public static void beforeAllTests() throws Exception {
-    rand = new SecureRandom();
     servers = generateServers(7);
     groupMap = constructGroupInfo(servers, groups);
-    tableMap = new HashMap<TableName, String>();
-    tableDescs = constructTableDesc();
+    tableDescs = constructTableDesc(true);
     Configuration conf = HBaseConfiguration.create();
     conf.set("hbase.regions.slop", "0");
     conf.set("hbase.group.grouploadbalancer.class", SimpleLoadBalancer.class.getCanonicalName());
@@ -109,11 +76,8 @@ public class TestRSGroupBasedLoadBalancer {
 
   /**
    * Test the load balancing algorithm.
-   *
    * Invariant is that all servers of the group should be hosting either floor(average) or
    * ceiling(average)
-   *
-   * @throws Exception
    */
   @Test
   public void testBalanceCluster() throws Exception {
@@ -128,49 +92,10 @@ public class TestRSGroupBasedLoadBalancer {
   }
 
   /**
-   * Invariant is that all servers of a group have load between floor(avg) and
-   * ceiling(avg) number of regions.
-   */
-  private void assertClusterAsBalanced(
-      ArrayListMultimap<String, ServerAndLoad> groupLoadMap) {
-    for (String gName : groupLoadMap.keySet()) {
-      List<ServerAndLoad> groupLoad = groupLoadMap.get(gName);
-      int numServers = groupLoad.size();
-      int numRegions = 0;
-      int maxRegions = 0;
-      int minRegions = Integer.MAX_VALUE;
-      for (ServerAndLoad server : groupLoad) {
-        int nr = server.getLoad();
-        if (nr > maxRegions) {
-          maxRegions = nr;
-        }
-        if (nr < minRegions) {
-          minRegions = nr;
-        }
-        numRegions += nr;
-      }
-      if (maxRegions - minRegions < 2) {
-        // less than 2 between max and min, can't balance
-        return;
-      }
-      int min = numRegions / numServers;
-      int max = numRegions % numServers == 0 ? min : min + 1;
-
-      for (ServerAndLoad server : groupLoad) {
-        assertTrue(server.getLoad() <= max);
-        assertTrue(server.getLoad() >= min);
-      }
-    }
-  }
-
-  /**
    * Tests the bulk assignment used during cluster startup.
-   *
    * Round-robin. Should yield a balanced cluster so same invariant as the
    * load balancer holds, all servers holding either floor(avg) or
    * ceiling(avg).
-   *
-   * @throws Exception
    */
   @Test
   public void testBulkAssignment() throws Exception {
@@ -203,8 +128,6 @@ public class TestRSGroupBasedLoadBalancer {
   /**
    * Test the cluster startup bulk assignment which attempts to retain
    * assignment info.
-   *
-   * @throws Exception
    */
   @Test
   public void testRetainAssignment() throws Exception {
@@ -233,9 +156,9 @@ public class TestRSGroupBasedLoadBalancer {
     Set<HRegionInfo> misplacedRegions = loadBalancer.getMisplacedRegions(inputForTest);
     assertFalse(misplacedRegions.contains(ri));
   }
+
   /**
    * Test BOGUS_SERVER_NAME among groups do not overwrite each other
-   * @throws Exception
    */
   @Test
   public void testRoundRobinAssignment() throws Exception {
@@ -263,331 +186,4 @@ public class TestRSGroupBasedLoadBalancer {
         .roundRobinAssignment(regions, onlineServers);
     assertEquals(bogusRegion, assignments.get(LoadBalancer.BOGUS_SERVER_NAME).size());
   }
-
-  /**
-   * Asserts a valid retained assignment plan.
-   * <p>
-   * Must meet the following conditions:
-   * <ul>
-   * <li>Every input region has an assignment, and to an online server
-   * <li>If a region had an existing assignment to a server with the same
-   * address a a currently online server, it will be assigned to it
-   * </ul>
-   *
-   * @param existing
-   * @param assignment
-   * @throws java.io.IOException
-   * @throws java.io.FileNotFoundException
-   */
-  private void assertRetainedAssignment(
-      Map<HRegionInfo, ServerName> existing, List<ServerName> servers,
-      Map<ServerName, List<HRegionInfo>> assignment)
-      throws FileNotFoundException, IOException {
-    // Verify condition 1, every region assigned, and to online server
-    Set<ServerName> onlineServerSet = new TreeSet<ServerName>(servers);
-    Set<HRegionInfo> assignedRegions = new TreeSet<HRegionInfo>();
-    for (Map.Entry<ServerName, List<HRegionInfo>> a : assignment.entrySet()) {
-      assertTrue(
-          "Region assigned to server that was not listed as online",
-          onlineServerSet.contains(a.getKey()));
-      for (HRegionInfo r : a.getValue())
-        assignedRegions.add(r);
-    }
-    assertEquals(existing.size(), assignedRegions.size());
-
-    // Verify condition 2, every region must be assigned to correct server.
-    Set<String> onlineHostNames = new TreeSet<String>();
-    for (ServerName s : servers) {
-      onlineHostNames.add(s.getHostname());
-    }
-
-    for (Map.Entry<ServerName, List<HRegionInfo>> a : assignment.entrySet()) {
-      ServerName currentServer = a.getKey();
-      for (HRegionInfo r : a.getValue()) {
-        ServerName oldAssignedServer = existing.get(r);
-        TableName tableName = r.getTable();
-        String groupName =
-            getMockedGroupInfoManager().getRSGroupOfTable(tableName);
-        assertTrue(StringUtils.isNotEmpty(groupName));
-        RSGroupInfo gInfo = getMockedGroupInfoManager().getRSGroup(
-            groupName);
-        assertTrue(
-            "Region is not correctly assigned to group servers.",
-            gInfo.containsServer(currentServer.getAddress()));
-        if (oldAssignedServer != null
-            && onlineHostNames.contains(oldAssignedServer
-            .getHostname())) {
-          // this region was previously assigned somewhere, and that
-          // host is still around, then the host must have been is a
-          // different group.
-          if (!oldAssignedServer.getAddress().equals(currentServer.getAddress())) {
-            assertFalse(gInfo.containsServer(oldAssignedServer.getAddress()));
-          }
-        }
-      }
-    }
-  }
-
-  private String printStats(
-      ArrayListMultimap<String, ServerAndLoad> groupBasedLoad) {
-    StringBuffer sb = new StringBuffer();
-    sb.append("\n");
-    for (String groupName : groupBasedLoad.keySet()) {
-      sb.append("Stats for group: " + groupName);
-      sb.append("\n");
-      sb.append(groupMap.get(groupName).getServers());
-      sb.append("\n");
-      List<ServerAndLoad> groupLoad = groupBasedLoad.get(groupName);
-      int numServers = groupLoad.size();
-      int totalRegions = 0;
-      sb.append("Per Server Load: \n");
-      for (ServerAndLoad sLoad : groupLoad) {
-        sb.append("Server :" + sLoad.getServerName() + " Load : "
-            + sLoad.getLoad() + "\n");
-        totalRegions += sLoad.getLoad();
-      }
-      sb.append(" Group Statistics : \n");
-      float average = (float) totalRegions / numServers;
-      int max = (int) Math.ceil(average);
-      int min = (int) Math.floor(average);
-      sb.append("[srvr=" + numServers + " rgns=" + totalRegions + " avg="
-          + average + " max=" + max + " min=" + min + "]");
-      sb.append("\n");
-      sb.append("===============================");
-      sb.append("\n");
-    }
-    return sb.toString();
-  }
-
-  private ArrayListMultimap<String, ServerAndLoad> convertToGroupBasedMap(
-      final Map<ServerName, List<HRegionInfo>> serversMap) throws IOException {
-    ArrayListMultimap<String, ServerAndLoad> loadMap = ArrayListMultimap
-        .create();
-    for (RSGroupInfo gInfo : getMockedGroupInfoManager().listRSGroups()) {
-      Set<Address> groupServers = gInfo.getServers();
-      for (Address server : groupServers) {
-        ServerName actual = null;
-        for(ServerName entry: servers) {
-          if(entry.getAddress().equals(server)) {
-            actual = entry;
-            break;
-          }
-        }
-        List<HRegionInfo> regions = serversMap.get(actual);
-        assertTrue("No load for " + actual, regions != null);
-        loadMap.put(gInfo.getName(),
-            new ServerAndLoad(actual, regions.size()));
-      }
-    }
-    return loadMap;
-  }
-
-  private ArrayListMultimap<String, ServerAndLoad> reconcile(
-      ArrayListMultimap<String, ServerAndLoad> previousLoad,
-      List<RegionPlan> plans) {
-    ArrayListMultimap<String, ServerAndLoad> result = ArrayListMultimap
-        .create();
-    result.putAll(previousLoad);
-    if (plans != null) {
-      for (RegionPlan plan : plans) {
-        ServerName source = plan.getSource();
-        updateLoad(result, source, -1);
-        ServerName destination = plan.getDestination();
-        updateLoad(result, destination, +1);
-      }
-    }
-    return result;
-  }
-
-  private void updateLoad(
-      ArrayListMultimap<String, ServerAndLoad> previousLoad,
-      final ServerName sn, final int diff) {
-    for (String groupName : previousLoad.keySet()) {
-      ServerAndLoad newSAL = null;
-      ServerAndLoad oldSAL = null;
-      for (ServerAndLoad sal : previousLoad.get(groupName)) {
-        if (ServerName.isSameAddress(sn, sal.getServerName())) {
-          oldSAL = sal;
-          newSAL = new ServerAndLoad(sn, sal.getLoad() + diff);
-          break;
-        }
-      }
-      if (newSAL != null) {
-        previousLoad.remove(groupName, oldSAL);
-        previousLoad.put(groupName, newSAL);
-        break;
-      }
-    }
-  }
-
-  private Map<ServerName, List<HRegionInfo>> mockClusterServers() throws IOException {
-    assertTrue(servers.size() == regionAssignment.length);
-    Map<ServerName, List<HRegionInfo>> assignment = new TreeMap<ServerName, List<HRegionInfo>>();
-    for (int i = 0; i < servers.size(); i++) {
-      int numRegions = regionAssignment[i];
-      List<HRegionInfo> regions = assignedRegions(numRegions, servers.get(i));
-      assignment.put(servers.get(i), regions);
-    }
-    return assignment;
-  }
-
-  /**
-   * Generate a list of regions evenly distributed between the tables.
-   *
-   * @param numRegions The number of regions to be generated.
-   * @return List of HRegionInfo.
-   */
-  private List<HRegionInfo> randomRegions(int numRegions) {
-    List<HRegionInfo> regions = new ArrayList<HRegionInfo>(numRegions);
-    byte[] start = new byte[16];
-    byte[] end = new byte[16];
-    rand.nextBytes(start);
-    rand.nextBytes(end);
-    int regionIdx = rand.nextInt(tables.length);
-    for (int i = 0; i < numRegions; i++) {
-      Bytes.putInt(start, 0, numRegions << 1);
-      Bytes.putInt(end, 0, (numRegions << 1) + 1);
-      int tableIndex = (i + regionIdx) % tables.length;
-      HRegionInfo hri = new HRegionInfo(
-          tables[tableIndex], start, end, false, regionId++);
-      regions.add(hri);
-    }
-    return regions;
-  }
-
-  /**
-   * Generate assigned regions to a given server using group information.
-   *
-   * @param numRegions the num regions to generate
-   * @param sn the servername
-   * @return the list of regions
-   * @throws java.io.IOException Signals that an I/O exception has occurred.
-   */
-  private List<HRegionInfo> assignedRegions(int numRegions, ServerName sn) throws IOException {
-    List<HRegionInfo> regions = new ArrayList<HRegionInfo>(numRegions);
-    byte[] start = new byte[16];
-    byte[] end = new byte[16];
-    Bytes.putInt(start, 0, numRegions << 1);
-    Bytes.putInt(end, 0, (numRegions << 1) + 1);
-    for (int i = 0; i < numRegions; i++) {
-      TableName tableName = getTableName(sn);
-      HRegionInfo hri = new HRegionInfo(
-          tableName, start, end, false,
-          regionId++);
-      regions.add(hri);
-    }
-    return regions;
-  }
-
-  private static List<ServerName> generateServers(int numServers) {
-    List<ServerName> servers = new ArrayList<ServerName>(numServers);
-    for (int i = 0; i < numServers; i++) {
-      String host = "server" + rand.nextInt(100000);
-      int port = rand.nextInt(60000);
-      servers.add(ServerName.valueOf(host, port, -1));
-    }
-    return servers;
-  }
-
-  /**
-   * Construct group info, with each group having at least one server.
-   *
-   * @param servers the servers
-   * @param groups the groups
-   * @return the map
-   */
-  private static Map<String, RSGroupInfo> constructGroupInfo(
-      List<ServerName> servers, String[] groups) {
-    assertTrue(servers != null);
-    assertTrue(servers.size() >= groups.length);
-    int index = 0;
-    Map<String, RSGroupInfo> groupMap = new HashMap<String, RSGroupInfo>();
-    for (String grpName : groups) {
-      RSGroupInfo RSGroupInfo = new RSGroupInfo(grpName);
-      RSGroupInfo.addServer(servers.get(index).getAddress());
-      groupMap.put(grpName, RSGroupInfo);
-      index++;
-    }
-    while (index < servers.size()) {
-      int grpIndex = rand.nextInt(groups.length);
-      groupMap.get(groups[grpIndex]).addServer(servers.get(index).getAddress());
-      index++;
-    }
-    return groupMap;
-  }
-
-  /**
-   * Construct table descriptors evenly distributed between the groups.
-   *
-   * @return the list
-   */
-  private static List<HTableDescriptor> constructTableDesc() {
-    List<HTableDescriptor> tds = Lists.newArrayList();
-    int index = rand.nextInt(groups.length);
-    for (int i = 0; i < tables.length; i++) {
-      HTableDescriptor htd = new HTableDescriptor(tables[i]);
-      int grpIndex = (i + index) % groups.length ;
-      String groupName = groups[grpIndex];
-      tableMap.put(tables[i], groupName);
-      tds.add(htd);
-    }
-    tableMap.put(table0, "");
-    tds.add(new HTableDescriptor(table0));
-    return tds;
-  }
-
-  private static MasterServices getMockedMaster() throws IOException {
-    TableDescriptors tds = Mockito.mock(TableDescriptors.class);
-    Mockito.when(tds.get(tables[0])).thenReturn(tableDescs.get(0));
-    Mockito.when(tds.get(tables[1])).thenReturn(tableDescs.get(1));
-    Mockito.when(tds.get(tables[2])).thenReturn(tableDescs.get(2));
-    Mockito.when(tds.get(tables[3])).thenReturn(tableDescs.get(3));
-    MasterServices services = Mockito.mock(HMaster.class);
-    Mockito.when(services.getTableDescriptors()).thenReturn(tds);
-    AssignmentManager am = Mockito.mock(AssignmentManager.class);
-    Mockito.when(services.getAssignmentManager()).thenReturn(am);
-    return services;
-  }
-
-  private static RSGroupInfoManager getMockedGroupInfoManager() throws IOException {
-    RSGroupInfoManager gm = Mockito.mock(RSGroupInfoManager.class);
-    Mockito.when(gm.getRSGroup(groups[0])).thenReturn(
-        groupMap.get(groups[0]));
-    Mockito.when(gm.getRSGroup(groups[1])).thenReturn(
-        groupMap.get(groups[1]));
-    Mockito.when(gm.getRSGroup(groups[2])).thenReturn(
-        groupMap.get(groups[2]));
-    Mockito.when(gm.getRSGroup(groups[3])).thenReturn(
-        groupMap.get(groups[3]));
-    Mockito.when(gm.listRSGroups()).thenReturn(
-        Lists.newLinkedList(groupMap.values()));
-    Mockito.when(gm.isOnline()).thenReturn(true);
-    Mockito.when(gm.getRSGroupOfTable(Mockito.any(TableName.class)))
-        .thenAnswer(new Answer<String>() {
-          @Override
-          public String answer(InvocationOnMock invocation) throws Throwable {
-            return tableMap.get(invocation.getArguments()[0]);
-          }
-        });
-    return gm;
-  }
-
-  private TableName getTableName(ServerName sn) throws IOException {
-    TableName tableName = null;
-    RSGroupInfoManager gm = getMockedGroupInfoManager();
-    RSGroupInfo groupOfServer = null;
-    for(RSGroupInfo gInfo : gm.listRSGroups()){
-      if(gInfo.containsServer(sn.getAddress())){
-        groupOfServer = gInfo;
-        break;
-      }
-    }
-
-    for(HTableDescriptor desc : tableDescs){
-      if(gm.getRSGroupOfTable(desc.getTableName()).endsWith(groupOfServer.getName())){
-        tableName = desc.getTableName();
-      }
-    }
-    return tableName;
-  }
 }

http://git-wip-us.apache.org/repos/asf/hbase/blob/71a0a3cf/hbase-rsgroup/src/test/java/org/apache/hadoop/hbase/master/balancer/TestRSGroupBasedLoadBalancerWithStochasticLoadBalancerAsInternal.java
----------------------------------------------------------------------
diff --git a/hbase-rsgroup/src/test/java/org/apache/hadoop/hbase/master/balancer/TestRSGroupBasedLoadBalancerWithStochasticLoadBalancerAsInternal.java b/hbase-rsgroup/src/test/java/org/apache/hadoop/hbase/master/balancer/TestRSGroupBasedLoadBalancerWithStochasticLoadBalancerAsInternal.java
new file mode 100644
index 0000000..771b59f
--- /dev/null
+++ b/hbase-rsgroup/src/test/java/org/apache/hadoop/hbase/master/balancer/TestRSGroupBasedLoadBalancerWithStochasticLoadBalancerAsInternal.java
@@ -0,0 +1,161 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.master.balancer;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hbase.ClusterStatus;
+import org.apache.hadoop.hbase.HBaseConfiguration;
+import org.apache.hadoop.hbase.HBaseIOException;
+import org.apache.hadoop.hbase.HRegionInfo;
+import org.apache.hadoop.hbase.RegionLoad;
+import org.apache.hadoop.hbase.ServerLoad;
+import org.apache.hadoop.hbase.ServerName;
+import org.apache.hadoop.hbase.master.RegionPlan;
+import org.apache.hadoop.hbase.rsgroup.RSGroupBasedLoadBalancer;
+import org.apache.hadoop.hbase.rsgroup.RSGroupInfo;
+import org.apache.hadoop.hbase.testclassification.SmallTests;
+import org.apache.hadoop.hbase.util.Bytes;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+/**
+ * Test RSGroupBasedLoadBalancer with StochasticLoadBalancer as internal balancer (HBASE-20791).
+ */
+@Category(SmallTests.class)
+public class TestRSGroupBasedLoadBalancerWithStochasticLoadBalancerAsInternal
+    extends RSGroupableBalancerTestBase {
+  private static RSGroupBasedLoadBalancer loadBalancer;
+
+  @BeforeClass
+  public static void beforeAllTests() throws Exception {
+    groups = new String[] { RSGroupInfo.DEFAULT_GROUP };
+    servers = generateServers(3);
+    groupMap = constructGroupInfo(servers, groups);
+    tableDescs = constructTableDesc(false);
+    Configuration conf = HBaseConfiguration.create();
+    conf.set("hbase.regions.slop", "0");
+    // presumably weighted so that read-request cost dominates the other cost functions
+    conf.setFloat("hbase.master.balancer.stochastic.readRequestCost", 10000f);
+    conf.set("hbase.rsgroup.grouploadbalancer.class",
+        StochasticLoadBalancer.class.getCanonicalName());
+    loadBalancer = new RSGroupBasedLoadBalancer(getMockedGroupInfoManager());
+    loadBalancer.setMasterServices(getMockedMaster());
+    loadBalancer.setConf(conf);
+    loadBalancer.initialize();
+  }
+
+  /** Mocks a ServerLoad whose regions each report {@code readRequestCount} reads only. */
+  private ServerLoad mockServerLoadWithReadRequests(List<HRegionInfo> regionsOnServer,
+      long readRequestCount) {
+    ServerLoad serverLoad = mock(ServerLoad.class);
+    Map<byte[], RegionLoad> regionLoadMap = new TreeMap<>(Bytes.BYTES_COMPARATOR);
+    for (HRegionInfo info : regionsOnServer) {
+      RegionLoad rl = mock(RegionLoad.class);
+      when(rl.getReadRequestsCount()).thenReturn(readRequestCount);
+      when(rl.getWriteRequestsCount()).thenReturn(0L);
+      when(rl.getMemStoreSizeMB()).thenReturn(0);
+      when(rl.getStorefileSizeMB()).thenReturn(0);
+      regionLoadMap.put(info.getEncodedNameAsBytes(), rl);
+    }
+    when(serverLoad.getRegionsLoad()).thenReturn(regionLoadMap);
+    return serverLoad;
+  }
+
+  /** Mocks a ClusterStatus whose servers and per-server loads come from {@code loads}. */
+  private static ClusterStatus mockClusterStatus(final Map<ServerName, ServerLoad> loads) {
+    ClusterStatus clusterStatus = mock(ClusterStatus.class);
+    when(clusterStatus.getServers()).thenReturn(loads.keySet());
+    when(clusterStatus.getLoad(Mockito.any(ServerName.class)))
+        .thenAnswer(new Answer<ServerLoad>() {
+          @Override
+          public ServerLoad answer(InvocationOnMock invocation) throws Throwable {
+            return loads.get(invocation.getArguments()[0]);
+          }
+        });
+    return clusterStatus;
+  }
+
+  /**
+   * Test HBASE-20791: read load reported via setClusterStatus must reach the internal
+   * balancer and make it move regions off the hot server.
+   */
+  @Test
+  public void testBalanceCluster() throws HBaseIOException {
+    // mock cluster state: three regions on each of the three servers
+    Map<ServerName, List<HRegionInfo>> clusterState = new HashMap<>();
+    ServerName serverA = servers.get(0);
+    ServerName serverB = servers.get(1);
+    ServerName serverC = servers.get(2);
+    List<HRegionInfo> regionsOnServerA = randomRegions(3);
+    List<HRegionInfo> regionsOnServerB = randomRegions(3);
+    List<HRegionInfo> regionsOnServerC = randomRegions(3);
+    clusterState.put(serverA, regionsOnServerA);
+    clusterState.put(serverB, regionsOnServerB);
+    clusterState.put(serverC, regionsOnServerC);
+
+    // first status report: no read requests on any region
+    Map<ServerName, ServerLoad> idleLoads = new TreeMap<>();
+    idleLoads.put(serverA, mockServerLoadWithReadRequests(regionsOnServerA, 0));
+    idleLoads.put(serverB, mockServerLoadWithReadRequests(regionsOnServerB, 0));
+    idleLoads.put(serverC, mockServerLoadWithReadRequests(regionsOnServerC, 0));
+    loadBalancer.setClusterStatus(mockClusterStatus(idleLoads));
+
+    // ReadRequestCostFunction is rate based, so report the cluster status a second time;
+    // this time the regions on serverA carry all the read requests:
+    // serverA : 1000,1000,1000
+    // serverB : 0,0,0
+    // serverC : 0,0,0
+    // so two regions should move from serverA, one to serverB and one to serverC
+    Map<ServerName, ServerLoad> hotLoads = new TreeMap<>();
+    hotLoads.put(serverA, mockServerLoadWithReadRequests(regionsOnServerA, 1000));
+    hotLoads.put(serverB, mockServerLoadWithReadRequests(regionsOnServerB, 0));
+    hotLoads.put(serverC, mockServerLoadWithReadRequests(regionsOnServerC, 0));
+    loadBalancer.setClusterStatus(mockClusterStatus(hotLoads));
+
+    List<RegionPlan> plans = loadBalancer.balanceCluster(clusterState);
+    assertTrue("balanceCluster returned null instead of a plan list", plans != null);
+    Set<HRegionInfo> regionsMoveFromServerA = new HashSet<>();
+    Set<ServerName> targetServers = new HashSet<>();
+    for (RegionPlan plan : plans) {
+      // null-safe: compare from serverA in case a plan carries no source
+      if (serverA.equals(plan.getSource())) {
+        regionsMoveFromServerA.add(plan.getRegionInfo());
+        targetServers.add(plan.getDestination());
+      }
+    }
+    assertEquals(2, regionsMoveFromServerA.size());
+    assertEquals(2, targetServers.size());
+    assertTrue(regionsOnServerA.containsAll(regionsMoveFromServerA));
+  }
+}