Posted to commits@hudi.apache.org by na...@apache.org on 2020/11/02 16:34:27 UTC

[hudi] branch master updated: Use RateLimiter instead of sleep. Repartition WriteStatus to optimize Hbase index writes (#1484)

This is an automated email from the ASF dual-hosted git repository.

nagarwal pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git


The following commit(s) were added to refs/heads/master by this push:
     new 59f995a  Use RateLimiter instead of sleep. Repartition WriteStatus to optimize Hbase index writes (#1484)
59f995a is described below

commit 59f995a3f5476e8171f24a32250da85e13d28daf
Author: Venkatesh Rudraraju <33...@users.noreply.github.com>
AuthorDate: Mon Nov 2 08:33:27 2020 -0800

    Use RateLimiter instead of sleep. Repartition WriteStatus to optimize Hbase index writes (#1484)
---
 .../hudi/index/hbase/SparkHoodieHBaseIndex.java    | 211 +++++++++++++++------
 .../apache/hudi/index/hbase/TestHBaseIndex.java    |  93 ++++++++-
 .../hbase/TestHBasePutBatchSizeCalculator.java     |  35 ++--
 .../org/apache/hudi/common/util/RateLimiter.java   |  91 +++++++++
 .../apache/hudi/common/util/TestRatelimiter.java   |  52 +++++
 5 files changed, 398 insertions(+), 84 deletions(-)
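
At a high level, this change replaces the fixed 100ms sleep between HBase batches with a
permit-based RateLimiter, and repartitions the WriteStatus RDD so that every WriteStatus
containing inserts gets its own Spark partition. The sketch below is illustrative only
(sendToHBase is a placeholder, not a method from this patch) and contrasts the two
throttling styles:

    import java.util.List;

    import org.apache.hudi.common.util.RateLimiter;

    class ThrottlingSketch {
      // Old approach: send a batch, then sleep a fixed 100ms regardless of how large the batch was.
      static void oldStyle(List<Object> batch) throws InterruptedException {
        sendToHBase(batch);
        Thread.sleep(100);
      }

      // New approach: account one permit per operation against a per-second limiter
      // (created elsewhere via RateLimiter.create(opsPerSecond, TimeUnit.SECONDS)),
      // so throughput follows the computed QPS budget instead of a fixed pause.
      static void newStyle(List<Object> batch, RateLimiter limiter) {
        limiter.tryAcquire(batch.size()); // blocks until enough permits are available
        sendToHBase(batch);
      }

      static void sendToHBase(List<Object> batch) {
        // placeholder for hTable.get(...) / mutator.mutate(...)
      }
    }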

diff --git a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/index/hbase/SparkHoodieHBaseIndex.java b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/index/hbase/SparkHoodieHBaseIndex.java
index 072c71c..77659b7 100644
--- a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/index/hbase/SparkHoodieHBaseIndex.java
+++ b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/index/hbase/SparkHoodieHBaseIndex.java
@@ -30,6 +30,7 @@ import org.apache.hudi.common.model.HoodieRecordPayload;
 import org.apache.hudi.common.table.HoodieTableMetaClient;
 import org.apache.hudi.common.table.timeline.HoodieTimeline;
 import org.apache.hudi.common.util.Option;
+import org.apache.hudi.common.util.RateLimiter;
 import org.apache.hudi.common.util.ReflectionUtils;
 import org.apache.hudi.config.HoodieHBaseIndexConfig;
 import org.apache.hudi.config.HoodieWriteConfig;
@@ -55,18 +56,23 @@ import org.apache.hadoop.hbase.client.Result;
 import org.apache.hadoop.hbase.util.Bytes;
 import org.apache.log4j.LogManager;
 import org.apache.log4j.Logger;
+import org.apache.spark.Partitioner;
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function2;
+import org.joda.time.DateTime;
 
 import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
 
 import scala.Tuple2;
 
@@ -84,13 +90,14 @@ public class SparkHoodieHBaseIndex<T extends HoodieRecordPayload> extends SparkH
   private static final byte[] COMMIT_TS_COLUMN = Bytes.toBytes("commit_ts");
   private static final byte[] FILE_NAME_COLUMN = Bytes.toBytes("file_name");
   private static final byte[] PARTITION_PATH_COLUMN = Bytes.toBytes("partition_path");
-  private static final int SLEEP_TIME_MILLISECONDS = 100;
 
   private static final Logger LOG = LogManager.getLogger(SparkHoodieHBaseIndex.class);
   private static Connection hbaseConnection = null;
   private HBaseIndexQPSResourceAllocator hBaseIndexQPSResourceAllocator = null;
-  private float qpsFraction;
   private int maxQpsPerRegionServer;
+  private long totalNumInserts;
+  private int numWriteStatusWithInserts;
+
   /**
    * multiPutBatchSize will be computed and re-set in updateLocation if
    * {@link HoodieHBaseIndexConfig#HBASE_PUT_BATCH_SIZE_AUTO_COMPUTE_PROP} is set to true.
@@ -109,7 +116,6 @@ public class SparkHoodieHBaseIndex<T extends HoodieRecordPayload> extends SparkH
 
   private void init(HoodieWriteConfig config) {
     this.multiPutBatchSize = config.getHbaseIndexGetBatchSize();
-    this.qpsFraction = config.getHbaseIndexQPSFraction();
     this.maxQpsPerRegionServer = config.getHbaseIndexMaxQPSPerRegionServer();
     this.putBatchSizeCalculator = new HBasePutBatchSizeCalculator();
     this.hBaseIndexQPSResourceAllocator = createQPSResourceAllocator(this.config);
@@ -163,7 +169,7 @@ public class SparkHoodieHBaseIndex<T extends HoodieRecordPayload> extends SparkH
    */
   @Override
   public void close() {
-    this.hBaseIndexQPSResourceAllocator.releaseQPSResources();
+    LOG.info("No resources to release from Hbase index");
   }
 
   private Get generateStatement(String key) throws IOException {
@@ -185,12 +191,14 @@ public class SparkHoodieHBaseIndex<T extends HoodieRecordPayload> extends SparkH
   private Function2<Integer, Iterator<HoodieRecord<T>>, Iterator<HoodieRecord<T>>> locationTagFunction(
       HoodieTableMetaClient metaClient) {
 
+    // `multiGetBatchSize` is intended to be issued as one batch per 100ms. To create a rate limiter that
+    // measures operations per second, we need to multiply `multiGetBatchSize` by 10.
+    Integer multiGetBatchSize = config.getHbaseIndexGetBatchSize();
     return (Function2<Integer, Iterator<HoodieRecord<T>>, Iterator<HoodieRecord<T>>>) (partitionNum,
         hoodieRecordIterator) -> {
 
-      int multiGetBatchSize = config.getHbaseIndexGetBatchSize();
       boolean updatePartitionPath = config.getHbaseIndexUpdatePartitionPath();
-
+      RateLimiter limiter = RateLimiter.create(multiGetBatchSize * 10, TimeUnit.SECONDS);
       // Grab the global HBase connection
       synchronized (SparkHoodieHBaseIndex.class) {
         if (hbaseConnection == null || hbaseConnection.isClosed()) {
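
To illustrate the conversion in the comment above (the batch size is assumed for illustration,
not a default taken from this patch): 100 gets per 100ms is equivalent to 1,000 gets per second,
so the limiter is created for the per-second rate:

    // 100 gets per 100ms  ==  1,000 gets per second
    RateLimiter limiter = RateLimiter.create(100 * 10, TimeUnit.SECONDS);
    // then, per batch: limiter.tryAcquire(statements.size()); Result[] results = hTable.get(statements);
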
@@ -211,7 +219,7 @@ public class SparkHoodieHBaseIndex<T extends HoodieRecordPayload> extends SparkH
             continue;
           }
           // get results for batch from Hbase
-          Result[] results = doGet(hTable, statements);
+          Result[] results = doGet(hTable, statements, limiter);
           // clear statements to be GC'd
           statements.clear();
           for (Result result : results) {
@@ -262,9 +270,12 @@ public class SparkHoodieHBaseIndex<T extends HoodieRecordPayload> extends SparkH
     };
   }
 
-  private Result[] doGet(HTable hTable, List<Get> keys) throws IOException {
-    sleepForTime(SLEEP_TIME_MILLISECONDS);
-    return hTable.get(keys);
+  private Result[] doGet(HTable hTable, List<Get> keys, RateLimiter limiter) throws IOException {
+    if (keys.size() > 0) {
+      limiter.tryAcquire(keys.size());
+      return hTable.get(keys);
+    }
+    return new Result[keys.size()];
   }
 
   @Override
@@ -285,11 +296,21 @@ public class SparkHoodieHBaseIndex<T extends HoodieRecordPayload> extends SparkH
           hbaseConnection = getHBaseConnection();
         }
       }
+      final long startTimeForPutsTask = DateTime.now().getMillis();
+      LOG.info("startTimeForPutsTask for this task: " + startTimeForPutsTask);
+
       try (BufferedMutator mutator = hbaseConnection.getBufferedMutator(TableName.valueOf(tableName))) {
+        final RateLimiter limiter = RateLimiter.create(multiPutBatchSize, TimeUnit.SECONDS);
         while (statusIterator.hasNext()) {
           WriteStatus writeStatus = statusIterator.next();
           List<Mutation> mutations = new ArrayList<>();
           try {
+            long numOfInserts = writeStatus.getStat().getNumInserts();
+            LOG.info("Num of inserts in this WriteStatus: " + numOfInserts);
+            LOG.info("Total inserts in this job: " + this.totalNumInserts);
+            LOG.info("multiPutBatchSize for this job: " + this.multiPutBatchSize);
+            // Create a rate limiter that allows `multiPutBatchSize` operations per second
+            // Any calls beyond `multiPutBatchSize` within a second will be rate limited
             for (HoodieRecord rec : writeStatus.getWrittenRecords()) {
               if (!writeStatus.isErrored(rec.getKey())) {
                 Option<HoodieRecordLocation> loc = rec.getNewLocation();
@@ -312,10 +333,10 @@ public class SparkHoodieHBaseIndex<T extends HoodieRecordPayload> extends SparkH
               if (mutations.size() < multiPutBatchSize) {
                 continue;
               }
-              doMutations(mutator, mutations);
+              doMutations(mutator, mutations, limiter);
             }
             // process remaining puts and deletes, if any
-            doMutations(mutator, mutations);
+            doMutations(mutator, mutations, limiter);
           } catch (Exception e) {
             Exception we = new Exception("Error updating index for " + writeStatus, e);
             LOG.error(we);
@@ -323,6 +344,8 @@ public class SparkHoodieHBaseIndex<T extends HoodieRecordPayload> extends SparkH
           }
           writeStatusList.add(writeStatus);
         }
+        final long endPutsTime = DateTime.now().getMillis();
+        LOG.info("hbase puts task time for this task: " + (endPutsTime - startTimeForPutsTask));
       } catch (IOException e) {
         throw new HoodieIndexException("Failed to Update Index locations because of exception with HBase Client", e);
       }
@@ -333,67 +356,95 @@ public class SparkHoodieHBaseIndex<T extends HoodieRecordPayload> extends SparkH
   /**
    * Helper method to facilitate performing mutations (including puts and deletes) in Hbase.
    */
-  private void doMutations(BufferedMutator mutator, List<Mutation> mutations) throws IOException {
+  private void doMutations(BufferedMutator mutator, List<Mutation> mutations, RateLimiter limiter) throws IOException {
     if (mutations.isEmpty()) {
       return;
     }
+    // Report the number of operations to the rate limiter.
+    // Once the configured number of permits has been acquired within one second, the rate limiter
+    // throttles the remaining calls until permits are released for the next second.
+    limiter.tryAcquire(mutations.size());
     mutator.mutate(mutations);
     mutator.flush();
     mutations.clear();
-    sleepForTime(SLEEP_TIME_MILLISECONDS);
   }
 
-  private static void sleepForTime(int sleepTimeMs) {
-    try {
-      Thread.sleep(sleepTimeMs);
-    } catch (InterruptedException e) {
-      LOG.error("Sleep interrupted during throttling", e);
-      throw new RuntimeException(e);
+  public Map<String, Integer> mapFileWithInsertsToUniquePartition(JavaRDD<WriteStatus> writeStatusRDD) {
+    final Map<String, Integer> fileIdPartitionMap = new HashMap<>();
+    int partitionIndex = 0;
+    // Map each fileId that has inserts to a unique partition Id. This will be used while
+    // repartitioning RDD<WriteStatus>
+    final List<String> fileIds = writeStatusRDD.filter(w -> w.getStat().getNumInserts() > 0)
+                                   .map(w -> w.getFileId()).collect();
+    for (final String fileId : fileIds) {
+      fileIdPartitionMap.put(fileId, partitionIndex++);
     }
+    return fileIdPartitionMap;
   }
 
   @Override
   public JavaRDD<WriteStatus> updateLocation(JavaRDD<WriteStatus> writeStatusRDD, HoodieEngineContext context,
-                                             HoodieTable<T, JavaRDD<HoodieRecord<T>>, JavaRDD<HoodieKey>, JavaRDD<WriteStatus>> hoodieTable) {
-    final HBaseIndexQPSResourceAllocator hBaseIndexQPSResourceAllocator = createQPSResourceAllocator(this.config);
-    setPutBatchSize(writeStatusRDD, hBaseIndexQPSResourceAllocator, context);
-    LOG.info("multiPutBatchSize: before hbase puts" + multiPutBatchSize);
-    JavaRDD<WriteStatus> writeStatusJavaRDD = writeStatusRDD.mapPartitionsWithIndex(updateLocationFunction(), true);
+                                             HoodieTable<T, JavaRDD<HoodieRecord<T>>, JavaRDD<HoodieKey>,
+                                                            JavaRDD<WriteStatus>> hoodieTable) {
+    final Option<Float> desiredQPSFraction =  calculateQPSFraction(writeStatusRDD);
+    final Map<String, Integer> fileIdPartitionMap = mapFileWithInsertsToUniquePartition(writeStatusRDD);
+    JavaRDD<WriteStatus> partitionedRDD = this.numWriteStatusWithInserts == 0 ? writeStatusRDD :
+                                          writeStatusRDD.mapToPair(w -> new Tuple2<>(w.getFileId(), w))
+                                              .partitionBy(new WriteStatusPartitioner(fileIdPartitionMap,
+                                                  this.numWriteStatusWithInserts))
+                                              .map(w -> w._2());
+    JavaSparkContext jsc = HoodieSparkEngineContext.getSparkContext(context);
+    acquireQPSResourcesAndSetBatchSize(desiredQPSFraction, jsc);
+    JavaRDD<WriteStatus> writeStatusJavaRDD = partitionedRDD.mapPartitionsWithIndex(updateLocationFunction(),
+        true);
     // caching the index updated status RDD
     writeStatusJavaRDD = writeStatusJavaRDD.persist(SparkMemoryUtils.getWriteStatusStorageLevel(config.getProps()));
+    // force trigger update location(hbase puts)
+    writeStatusJavaRDD.count();
+    this.hBaseIndexQPSResourceAllocator.releaseQPSResources();
     return writeStatusJavaRDD;
   }
 
-  private void setPutBatchSize(JavaRDD<WriteStatus> writeStatusRDD,
-      HBaseIndexQPSResourceAllocator hBaseIndexQPSResourceAllocator, final HoodieEngineContext context) {
+  private Option<Float> calculateQPSFraction(JavaRDD<WriteStatus> writeStatusRDD) {
     if (config.getHbaseIndexPutBatchSizeAutoCompute()) {
-      JavaSparkContext jsc = HoodieSparkEngineContext.getSparkContext(context);
-      SparkConf conf = jsc.getConf();
-      int maxExecutors = conf.getInt(DEFAULT_SPARK_EXECUTOR_INSTANCES_CONFIG_NAME, 1);
-      if (conf.getBoolean(DEFAULT_SPARK_DYNAMIC_ALLOCATION_ENABLED_CONFIG_NAME, false)) {
-        maxExecutors =
-            Math.max(maxExecutors, conf.getInt(DEFAULT_SPARK_DYNAMIC_ALLOCATION_MAX_EXECUTORS_CONFIG_NAME, 1));
-      }
-
       /*
-       * Each writeStatus represents status information from a write done in one of the IOHandles. If a writeStatus has
-       * any insert, it implies that the corresponding task contacts HBase for doing puts, since we only do puts for
-       * inserts from HBaseIndex.
+        Each writeStatus represents status information from a write done in one of the IOHandles.
+        If a writeStatus has any insert, it implies that the corresponding task contacts HBase for
+        doing puts, since we only do puts for inserts from HBaseIndex.
        */
-      final Tuple2<Long, Integer> numPutsParallelismTuple = getHBasePutAccessParallelism(writeStatusRDD);
-      final long numPuts = numPutsParallelismTuple._1;
-      final int hbasePutsParallelism = numPutsParallelismTuple._2;
+      final Tuple2<Long, Integer> numPutsParallelismTuple  = getHBasePutAccessParallelism(writeStatusRDD);
+      this.totalNumInserts = numPutsParallelismTuple._1;
+      this.numWriteStatusWithInserts = numPutsParallelismTuple._2;
       this.numRegionServersForTable = getNumRegionServersAliveForTable();
-      final float desiredQPSFraction =
-          hBaseIndexQPSResourceAllocator.calculateQPSFractionForPutsTime(numPuts, this.numRegionServersForTable);
+      final float desiredQPSFraction = this.hBaseIndexQPSResourceAllocator.calculateQPSFractionForPutsTime(
+          this.totalNumInserts, this.numRegionServersForTable);
       LOG.info("Desired QPSFraction :" + desiredQPSFraction);
-      LOG.info("Number HBase puts :" + numPuts);
-      LOG.info("Hbase Puts Parallelism :" + hbasePutsParallelism);
-      final float availableQpsFraction =
-          hBaseIndexQPSResourceAllocator.acquireQPSResources(desiredQPSFraction, numPuts);
+      LOG.info("Number HBase puts :" + this.totalNumInserts);
+      LOG.info("Number of WriteStatus with inserts :" + numWriteStatusWithInserts);
+      return Option.of(desiredQPSFraction);
+    }
+    return Option.empty();
+  }
+
+  private void acquireQPSResourcesAndSetBatchSize(final Option<Float> desiredQPSFraction,
+                                                  final JavaSparkContext jsc) {
+    if (config.getHbaseIndexPutBatchSizeAutoCompute()) {
+      SparkConf conf = jsc.getConf();
+      int maxExecutors = conf.getInt(DEFAULT_SPARK_EXECUTOR_INSTANCES_CONFIG_NAME, 1);
+      if (conf.getBoolean(DEFAULT_SPARK_DYNAMIC_ALLOCATION_ENABLED_CONFIG_NAME, false)) {
+        maxExecutors = Math.max(maxExecutors, conf.getInt(
+          DEFAULT_SPARK_DYNAMIC_ALLOCATION_MAX_EXECUTORS_CONFIG_NAME, 1));
+      }
+      final float availableQpsFraction = this.hBaseIndexQPSResourceAllocator
+                                           .acquireQPSResources(desiredQPSFraction.get(), this.totalNumInserts);
       LOG.info("Allocated QPS Fraction :" + availableQpsFraction);
-      multiPutBatchSize = putBatchSizeCalculator.getBatchSize(numRegionServersForTable, maxQpsPerRegionServer,
-          hbasePutsParallelism, maxExecutors, SLEEP_TIME_MILLISECONDS, availableQpsFraction);
+      multiPutBatchSize = putBatchSizeCalculator
+                            .getBatchSize(
+                              numRegionServersForTable,
+                              maxQpsPerRegionServer,
+                              numWriteStatusWithInserts,
+                              maxExecutors,
+                              availableQpsFraction);
       LOG.info("multiPutBatchSize :" + multiPutBatchSize);
     }
   }
@@ -406,7 +457,6 @@ public class SparkHoodieHBaseIndex<T extends HoodieRecordPayload> extends SparkH
 
   public static class HBasePutBatchSizeCalculator implements Serializable {
 
-    private static final int MILLI_SECONDS_IN_A_SECOND = 1000;
     private static final Logger LOG = LogManager.getLogger(HBasePutBatchSizeCalculator.class);
 
     /**
@@ -441,22 +491,26 @@ public class SparkHoodieHBaseIndex<T extends HoodieRecordPayload> extends SparkH
      * </li>
      * </p>
      */
-    public int getBatchSize(int numRegionServersForTable, int maxQpsPerRegionServer, int numTasksDuringPut,
-        int maxExecutors, int sleepTimeMs, float qpsFraction) {
-      int maxReqPerSec = (int) (qpsFraction * numRegionServersForTable * maxQpsPerRegionServer);
-      int maxParallelPuts = Math.max(1, Math.min(numTasksDuringPut, maxExecutors));
-      int maxReqsSentPerTaskPerSec = MILLI_SECONDS_IN_A_SECOND / sleepTimeMs;
-      int multiPutBatchSize = Math.max(1, maxReqPerSec / (maxParallelPuts * maxReqsSentPerTaskPerSec));
+    public int getBatchSize(int numRegionServersForTable, int maxQpsPerRegionServer,
+                            int numTasksDuringPut, int maxExecutors, float qpsFraction) {
+      int numRSAlive = numRegionServersForTable;
+      int maxReqPerSec = getMaxReqPerSec(numRSAlive, maxQpsPerRegionServer, qpsFraction);
+      int numTasks = numTasksDuringPut;
+      int maxParallelPutsTask = Math.max(1, Math.min(numTasks, maxExecutors));
+      int multiPutBatchSizePerSecPerTask = Math.max(1, (int) Math.ceil(maxReqPerSec / maxParallelPutsTask));
       LOG.info("HbaseIndexThrottling: qpsFraction :" + qpsFraction);
-      LOG.info("HbaseIndexThrottling: numRSAlive :" + numRegionServersForTable);
+      LOG.info("HbaseIndexThrottling: numRSAlive :" + numRSAlive);
       LOG.info("HbaseIndexThrottling: maxReqPerSec :" + maxReqPerSec);
-      LOG.info("HbaseIndexThrottling: numTasks :" + numTasksDuringPut);
+      LOG.info("HbaseIndexThrottling: numTasks :" + numTasks);
       LOG.info("HbaseIndexThrottling: maxExecutors :" + maxExecutors);
-      LOG.info("HbaseIndexThrottling: maxParallelPuts :" + maxParallelPuts);
-      LOG.info("HbaseIndexThrottling: maxReqsSentPerTaskPerSec :" + maxReqsSentPerTaskPerSec);
+      LOG.info("HbaseIndexThrottling: maxParallelPuts :" + maxParallelPutsTask);
       LOG.info("HbaseIndexThrottling: numRegionServersForTable :" + numRegionServersForTable);
-      LOG.info("HbaseIndexThrottling: multiPutBatchSize :" + multiPutBatchSize);
-      return multiPutBatchSize;
+      LOG.info("HbaseIndexThrottling: multiPutBatchSizePerSecPerTask :" + multiPutBatchSizePerSecPerTask);
+      return multiPutBatchSizePerSecPerTask;
+    }
+
+    public int getMaxReqPerSec(int numRegionServersForTable, int maxQpsPerRegionServer, float qpsFraction) {
+      return (int) (qpsFraction * numRegionServersForTable * maxQpsPerRegionServer);
     }
   }
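
Using the figures from the updated TestHBasePutBatchSizeCalculator below as an illustration
(10 region servers, 16667 max QPS per region server, 1200 put tasks, 200 executors, 0.1 QPS
fraction), the new calculation works out to:

    maxReqPerSec        = (int) (0.1 * 10 * 16667)   = 16,667
    maxParallelPutsTask = max(1, min(1200, 200))     = 200
    multiPutBatchSize   = max(1, 16667 / 200)        = 83   (integer division; per task, per second)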
 
@@ -510,4 +564,37 @@ public class SparkHoodieHBaseIndex<T extends HoodieRecordPayload> extends SparkH
   public void setHbaseConnection(Connection hbaseConnection) {
     SparkHoodieHBaseIndex.hbaseConnection = hbaseConnection;
   }
+
+  /**
+   * Partitions each WriteStatus that has inserts into its own single partition, while a WriteStatus without
+   * inserts is assigned to a random partition. This partitioner helps achieve maximum parallelism for Spark
+   * operations that depend on the inserts in each WriteStatus.
+   */
+  public static class WriteStatusPartitioner extends Partitioner {
+    private int totalPartitions;
+    final Map<String, Integer> fileIdPartitionMap;
+
+    public WriteStatusPartitioner(final Map<String, Integer> fileIdPartitionMap, final int totalPartitions) {
+      this.totalPartitions = totalPartitions;
+      this.fileIdPartitionMap = fileIdPartitionMap;
+    }
+
+    @Override
+    public int numPartitions() {
+      return this.totalPartitions;
+    }
+
+    @Override
+    public int getPartition(Object key) {
+      final String fileId = (String) key;
+      if (!fileIdPartitionMap.containsKey(fileId)) {
+        LOG.info("This writestatus(fileId: " + fileId + ") is not mapped because it doesn't have any inserts. "
+                 + "In this case, we can assign a random partition to this WriteStatus.");
+        // Assign random spark partition for the `WriteStatus` that has no inserts. For a spark operation that depends
+        // on number of inserts, there won't be any performance penalty in packing these WriteStatus'es together.
+        return Math.abs(fileId.hashCode()) % totalPartitions;
+      }
+      return fileIdPartitionMap.get(fileId);
+    }
+  }
 }
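
A minimal, self-contained sketch (illustrative fileIds and map, not taken from the patch) of how
the new WriteStatusPartitioner routes keys: a fileId that has inserts gets its dedicated partition
from the map, while any other fileId falls back to hashing, exactly as in getPartition() above:

    import java.util.HashMap;
    import java.util.Map;

    class WriteStatusPartitionerSketch {
      public static void main(String[] args) {
        Map<String, Integer> fileIdPartitionMap = new HashMap<>();
        fileIdPartitionMap.put("file-A", 0); // fileId with inserts -> dedicated partition 0
        fileIdPartitionMap.put("file-B", 1); // fileId with inserts -> dedicated partition 1
        int totalPartitions = 2;

        for (String fileId : new String[] {"file-A", "file-B", "file-C"}) { // file-C has no inserts
          int partition = fileIdPartitionMap.containsKey(fileId)
              ? fileIdPartitionMap.get(fileId)
              : Math.abs(fileId.hashCode()) % totalPartitions; // same fallback as getPartition()
          System.out.println(fileId + " -> partition " + partition);
        }
      }
    }
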
diff --git a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/index/hbase/TestHBaseIndex.java b/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/index/hbase/TestHBaseIndex.java
index b74daad..2eb672a 100644
--- a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/index/hbase/TestHBaseIndex.java
+++ b/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/index/hbase/TestHBaseIndex.java
@@ -62,6 +62,8 @@ import org.junit.jupiter.api.TestMethodOrder;
 import java.util.Arrays;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
+import java.util.UUID;
 import java.util.stream.Collectors;
 
 import scala.Tuple2;
@@ -382,13 +384,98 @@ public class TestHBaseIndex extends FunctionalTestHarness {
     HoodieWriteConfig config = getConfig();
     SparkHoodieHBaseIndex index = new SparkHoodieHBaseIndex(config);
     final JavaRDD<WriteStatus> writeStatusRDD = jsc().parallelize(
-        Arrays.asList(getSampleWriteStatus(1, 2), getSampleWriteStatus(0, 3), getSampleWriteStatus(10, 0)), 10);
+        Arrays.asList(
+            getSampleWriteStatus(0, 2),
+            getSampleWriteStatus(2, 3),
+            getSampleWriteStatus(4, 3),
+            getSampleWriteStatus(6, 3),
+            getSampleWriteStatus(8, 0)),
+        10);
     final Tuple2<Long, Integer> tuple = index.getHBasePutAccessParallelism(writeStatusRDD);
     final int hbasePutAccessParallelism = Integer.parseInt(tuple._2.toString());
     final int hbaseNumPuts = Integer.parseInt(tuple._1.toString());
     assertEquals(10, writeStatusRDD.getNumPartitions());
-    assertEquals(2, hbasePutAccessParallelism);
-    assertEquals(11, hbaseNumPuts);
+    assertEquals(4, hbasePutAccessParallelism);
+    assertEquals(20, hbaseNumPuts);
+  }
+
+  @Test
+  public void testsWriteStatusPartitioner() {
+    HoodieWriteConfig config = getConfig();
+    SparkHoodieHBaseIndex index = new SparkHoodieHBaseIndex(config);
+    int parallelism = 4;
+    final JavaRDD<WriteStatus> writeStatusRDD = jsc().parallelize(
+        Arrays.asList(
+            getSampleWriteStatusWithFileId(0, 2),
+            getSampleWriteStatusWithFileId(2, 3),
+            getSampleWriteStatusWithFileId(4, 3),
+            getSampleWriteStatusWithFileId(0, 3),
+            getSampleWriteStatusWithFileId(11, 0)), parallelism);
+
+    final Map<String, Integer> fileIdPartitionMap = index.mapFileWithInsertsToUniquePartition(writeStatusRDD);
+    int numWriteStatusWithInserts = (int) index.getHBasePutAccessParallelism(writeStatusRDD)._2;
+    JavaRDD<WriteStatus> partitionedRDD = writeStatusRDD.mapToPair(w -> new Tuple2<>(w.getFileId(), w))
+                                              .partitionBy(new SparkHoodieHBaseIndex
+                                                                   .WriteStatusPartitioner(fileIdPartitionMap,
+                                                  numWriteStatusWithInserts)).map(w -> w._2());
+    assertEquals(numWriteStatusWithInserts, partitionedRDD.getNumPartitions());
+    int[] partitionIndexesBeforeRepartition = writeStatusRDD.partitions().stream().mapToInt(p -> p.index()).toArray();
+    assertEquals(parallelism, partitionIndexesBeforeRepartition.length);
+
+    int[] partitionIndexesAfterRepartition = partitionedRDD.partitions().stream().mapToInt(p -> p.index()).toArray();
+    // there should be 3 partitions after repartition, because only 3 of the WriteStatuses have
+    // inserts (numWriteStatusWithInserts)
+    assertEquals(numWriteStatusWithInserts, partitionIndexesAfterRepartition.length);
+
+    List<WriteStatus>[] writeStatuses = partitionedRDD.collectPartitions(partitionIndexesAfterRepartition);
+    for (List<WriteStatus> list : writeStatuses) {
+      int count = 0;
+      for (WriteStatus w: list) {
+        if (w.getStat().getNumInserts() > 0)   {
+          count++;
+        }
+      }
+      assertEquals(1, count);
+    }
+  }
+
+  @Test
+  public void testsWriteStatusPartitionerWithNoInserts() {
+    HoodieWriteConfig config = getConfig();
+    SparkHoodieHBaseIndex index = new SparkHoodieHBaseIndex(config);
+    int parallelism = 3;
+    final JavaRDD<WriteStatus> writeStatusRDD = jsc().parallelize(
+        Arrays.asList(
+            getSampleWriteStatusWithFileId(0, 2),
+            getSampleWriteStatusWithFileId(0, 3),
+            getSampleWriteStatusWithFileId(0, 0)), parallelism);
+
+    final Map<String, Integer> fileIdPartitionMap = index.mapFileWithInsertsToUniquePartition(writeStatusRDD);
+    int numWriteStatusWithInserts = (int) index.getHBasePutAccessParallelism(writeStatusRDD)._2;
+    JavaRDD<WriteStatus> partitionedRDD = writeStatusRDD.mapToPair(w -> new Tuple2<>(w.getFileId(), w))
+                                              .partitionBy(new SparkHoodieHBaseIndex
+                                                                   .WriteStatusPartitioner(fileIdPartitionMap,
+                                                  numWriteStatusWithInserts)).map(w -> w._2());
+    assertEquals(numWriteStatusWithInserts, partitionedRDD.getNumPartitions());
+    int[] partitionIndexesBeforeRepartition = writeStatusRDD.partitions().stream().mapToInt(p -> p.index()).toArray();
+    assertEquals(parallelism, partitionIndexesBeforeRepartition.length);
+
+    int[] partitionIndexesAfterRepartition = partitionedRDD.partitions().stream().mapToInt(p -> p.index()).toArray();
+    // none of the WriteStatuses has inserts here, so the number of partitions after repartition
+    // should equal numWriteStatusWithInserts (0)
+    assertEquals(numWriteStatusWithInserts, partitionIndexesAfterRepartition.length);
+    assertEquals(partitionIndexesBeforeRepartition.length, parallelism);
+
+  }
+
+  private WriteStatus getSampleWriteStatusWithFileId(final int numInserts, final int numUpdateWrites) {
+    final WriteStatus writeStatus = new WriteStatus(false, 0.0);
+    HoodieWriteStat hoodieWriteStat = new HoodieWriteStat();
+    hoodieWriteStat.setNumInserts(numInserts);
+    hoodieWriteStat.setNumUpdateWrites(numUpdateWrites);
+    writeStatus.setStat(hoodieWriteStat);
+    writeStatus.setFileId(UUID.randomUUID().toString());
+    return writeStatus;
   }
 
   @Test
diff --git a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/index/hbase/TestHBasePutBatchSizeCalculator.java b/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/index/hbase/TestHBasePutBatchSizeCalculator.java
index e698eaf..a6068e6 100644
--- a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/index/hbase/TestHBasePutBatchSizeCalculator.java
+++ b/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/index/hbase/TestHBasePutBatchSizeCalculator.java
@@ -27,38 +27,35 @@ public class TestHBasePutBatchSizeCalculator {
 
   @Test
   public void testPutBatchSizeCalculation() {
-    SparkHoodieHBaseIndex.HBasePutBatchSizeCalculator batchSizeCalculator = new SparkHoodieHBaseIndex.HBasePutBatchSizeCalculator();
-
+    SparkHoodieHBaseIndex.HBasePutBatchSizeCalculator batchSizeCalculator = new SparkHoodieHBaseIndex
+                                                                                    .HBasePutBatchSizeCalculator();
     // All asserts cases below are derived out of the first
     // example below, with change in one parameter at a time.
-
-    int putBatchSize = batchSizeCalculator.getBatchSize(10, 16667, 1200, 200, 100, 0.1f);
-    // Expected batchSize is 8 because in that case, total request sent in one second is below
-    // 8 (batchSize) * 200 (parallelism) * 10 (maxReqsInOneSecond) * 10 (numRegionServers) * 0.1 (qpsFraction)) => 16000
-    // We assume requests get distributed to Region Servers uniformly, so each RS gets 1600 request
-    // 1600 happens to be 10% of 16667 (maxQPSPerRegionServer) as expected.
-    assertEquals(8, putBatchSize);
+    int putBatchSize = batchSizeCalculator.getBatchSize(10, 16667, 1200, 200, 0.1f);
+    // Total puts that can be sent in 1 second = (10 * 16667 * 0.1) = 16,667
+    // Puts per batch per task = 16,667 / 200 (maxExecutors caps the parallelism here) = 83
+    assertEquals(putBatchSize, 83);
 
     // Number of Region Servers are halved, total requests sent in a second are also halved, so batchSize is also halved
-    int putBatchSize2 = batchSizeCalculator.getBatchSize(5, 16667, 1200, 200, 100, 0.1f);
-    assertEquals(4, putBatchSize2);
+    int putBatchSize2 = batchSizeCalculator.getBatchSize(5, 16667, 1200, 200, 0.1f);
+    assertEquals(putBatchSize2, 41);
 
     // If the parallelism is halved, batchSize has to double
-    int putBatchSize3 = batchSizeCalculator.getBatchSize(10, 16667, 1200, 100, 100, 0.1f);
-    assertEquals(16, putBatchSize3);
+    int putBatchSize3 = batchSizeCalculator.getBatchSize(10, 16667, 1200, 100, 0.1f);
+    assertEquals(putBatchSize3, 166);
 
     // If the parallelism is halved, batchSize has to double.
     // This time parallelism is driven by numTasks rather than numExecutors
-    int putBatchSize4 = batchSizeCalculator.getBatchSize(10, 16667, 100, 200, 100, 0.1f);
-    assertEquals(16, putBatchSize4);
+    int putBatchSize4 = batchSizeCalculator.getBatchSize(10, 16667, 100, 200, 0.1f);
+    assertEquals(putBatchSize4, 166);
 
     // If qpsFraction is halved, batchSize has to halve
-    int putBatchSize5 = batchSizeCalculator.getBatchSize(10, 16667, 1200, 200, 100, 0.05f);
-    assertEquals(4, putBatchSize5);
+    int putBatchSize5 = batchSizeCalculator.getBatchSize(10, 16667, 1200, 200, 0.05f);
+    assertEquals(putBatchSize5, 41);
 
     // If maxQPSPerRegionServer is doubled, batchSize also doubles
-    int putBatchSize6 = batchSizeCalculator.getBatchSize(10, 33334, 1200, 200, 100, 0.1f);
-    assertEquals(16, putBatchSize6);
+    int putBatchSize6 = batchSizeCalculator.getBatchSize(10, 33334, 1200, 200, 0.1f);
+    assertEquals(putBatchSize6, 166);
   }
 
 }
diff --git a/hudi-common/src/main/java/org/apache/hudi/common/util/RateLimiter.java b/hudi-common/src/main/java/org/apache/hudi/common/util/RateLimiter.java
new file mode 100644
index 0000000..e156ccf
--- /dev/null
+++ b/hudi-common/src/main/java/org/apache/hudi/common/util/RateLimiter.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hudi.common.util;
+
+import org.apache.log4j.LogManager;
+import org.apache.log4j.Logger;
+
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.concurrent.ThreadSafe;
+
+@ThreadSafe
+public class RateLimiter {
+
+  private final Semaphore semaphore;
+  private final int maxPermits;
+  private final TimeUnit timePeriod;
+  private ScheduledExecutorService scheduler;
+  private static final long RELEASE_PERMITS_PERIOD_IN_SECONDS = 1L;
+  private static final long WAIT_BEFORE_NEXT_ACQUIRE_PERMIT_IN_MS = 5;
+  private static final int SCHEDULER_CORE_THREAD_POOL_SIZE = 1;
+
+  private static final Logger LOG = LogManager.getLogger(RateLimiter.class);
+
+  public static RateLimiter create(int permits, TimeUnit timePeriod) {
+    final RateLimiter limiter = new RateLimiter(permits, timePeriod);
+    limiter.releasePermitsPeriodically();
+    return limiter;
+  }
+
+  private RateLimiter(int permits, TimeUnit timePeriod) {
+    this.semaphore = new Semaphore(permits);
+    this.maxPermits = permits;
+    this.timePeriod = timePeriod;
+  }
+
+  public boolean tryAcquire(int numPermits) {
+    if (numPermits > maxPermits) {
+      acquire(maxPermits);
+      return tryAcquire(numPermits - maxPermits);
+    } else {
+      return acquire(numPermits);
+    }
+  }
+
+  public boolean acquire(int numOps) {
+    try {
+      if (!semaphore.tryAcquire(numOps)) {
+        Thread.sleep(WAIT_BEFORE_NEXT_ACQUIRE_PERMIT_IN_MS);
+        return acquire(numOps);
+      }
+      LOG.debug(String.format("acquire permits: %s, maxPremits: %s", numOps, maxPermits));
+    } catch (InterruptedException e) {
+      throw new RuntimeException("Unable to acquire permits", e);
+    }
+    return true;
+  }
+
+  public void stop() {
+    scheduler.shutdownNow();
+  }
+
+  public void releasePermitsPeriodically() {
+    scheduler = Executors.newScheduledThreadPool(SCHEDULER_CORE_THREAD_POOL_SIZE);
+    scheduler.scheduleAtFixedRate(() -> {
+      LOG.debug(String.format("Release permits: maxPremits: %s, available: %s", maxPermits,
+          semaphore.availablePermits()));
+      semaphore.release(maxPermits - semaphore.availablePermits());
+    }, RELEASE_PERMITS_PERIOD_IN_SECONDS, RELEASE_PERMITS_PERIOD_IN_SECONDS, timePeriod);
+
+  }
+
+}
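
A usage sketch for the new RateLimiter (the numbers are illustrative): create a limiter for a
per-second budget, acquire permits for each batch of work, and stop the background release thread
when done. tryAcquire blocks, by sleeping in small increments, once the current second's permits
are exhausted:

    import java.util.concurrent.TimeUnit;

    import org.apache.hudi.common.util.RateLimiter;

    class RateLimiterUsageSketch {
      public static void main(String[] args) {
        RateLimiter limiter = RateLimiter.create(500, TimeUnit.SECONDS); // 500 permits per second
        for (int batch = 0; batch < 5; batch++) {
          // Each call accounts for a batch of 200 operations; once 500 permits are used up
          // within a second, later calls wait for the scheduler to release fresh permits.
          limiter.tryAcquire(200);
          // ... perform the throttled work here, e.g. a batch of HBase puts ...
        }
        limiter.stop(); // shut down the permit-release scheduler
      }
    }
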
diff --git a/hudi-common/src/test/java/org/apache/hudi/common/util/TestRatelimiter.java b/hudi-common/src/test/java/org/apache/hudi/common/util/TestRatelimiter.java
new file mode 100644
index 0000000..c2e939c
--- /dev/null
+++ b/hudi-common/src/test/java/org/apache/hudi/common/util/TestRatelimiter.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hudi.common.util;
+
+import java.util.concurrent.TimeUnit;
+import org.junit.jupiter.api.Test;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class TestRatelimiter {
+
+  @Test
+  public void testRateLimiterWithNoThrottling() throws InterruptedException {
+    RateLimiter limiter =  RateLimiter.create(1000, TimeUnit.SECONDS);
+    long start = System.currentTimeMillis();
+    assertEquals(true, limiter.tryAcquire(1000));
+    // Sleep to represent some operation
+    Thread.sleep(500);
+    long end = System.currentTimeMillis();
+    // With a large permit limit, there shouldn't be any throttling of operations
+    assertTrue((end - start) < TimeUnit.SECONDS.toMillis(2));
+  }
+
+  @Test
+  public void testRateLimiterWithThrottling() throws InterruptedException {
+    RateLimiter limiter =  RateLimiter.create(100, TimeUnit.SECONDS);
+    long start = System.currentTimeMillis();
+    assertEquals(true, limiter.tryAcquire(400));
+    // Sleep to represent some operation
+    Thread.sleep(500);
+    long end = System.currentTimeMillis();
+    // Since 400 permits are requested but only 100 are released per second, the acquire spans
+    // multiple refill periods, so the whole execution should take at least 2 seconds
+    assertTrue((end - start) >= TimeUnit.SECONDS.toMillis(2));
+  }
+}