Posted to commits@celeborn.apache.org by zh...@apache.org on 2023/03/02 10:33:01 UTC

[incubator-celeborn] branch branch-0.2 updated: [CELEBORN-348] Support fetchTime in load-aware slots assignment strategy (#1296)

This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch branch-0.2
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/branch-0.2 by this push:
     new 197131550 [CELEBORN-348] Support fetchTime in load-aware slots assignment strategy (#1296)
197131550 is described below

commit 197131550c30e9582cd0ddc6a45cc3dc23118f23
Author: Keyong Zhou <zh...@apache.org>
AuthorDate: Thu Mar 2 18:32:57 2023 +0800

    [CELEBORN-348] Support fetchTime in load-aware slots assignment strategy (#1296)
---
 .../spark/shuffle/celeborn/SortBasedPusher.java    |  3 +-
 .../common/network/server/ChunkStreamManager.java  | 19 +++++--
 common/src/main/proto/TransportMessages.proto      |  1 +
 .../org/apache/celeborn/common/CelebornConf.scala  | 56 +++++++++++++++----
 .../apache/celeborn/common/meta/DeviceInfo.scala   | 48 +++++++++++++----
 .../apache/celeborn/common/meta/TimeWindow.scala   | 59 ++++++++++++++++++++
 .../apache/celeborn/common/meta/WorkerInfo.scala   |  1 +
 .../apache/celeborn/common/util/PbSerDeUtils.scala |  2 +
 .../network/server/ChunkStreamManagerSuiteJ.java   |  4 +-
 .../celeborn/common/meta/TimeWindowSuite.scala     | 42 +++++++++++++++
 .../celeborn/common/meta/WorkerInfoSuite.scala     | 18 +++----
 .../celeborn/common/util/PbSerDeUtilsTest.scala    |  5 +-
 docs/configuration/master.md                       |  2 +
 docs/configuration/worker.md                       |  2 +-
 .../service/deploy/master/SlotsAllocator.java      | 21 ++++++--
 .../deploy/master/clustermeta/MetaUtil.java        |  7 ++-
 master/src/main/proto/Resource.proto               |  1 +
 .../celeborn/service/deploy/master/Master.scala    | 11 ++--
 .../deploy/master/SlotsAllocatorSuiteJ.java        | 58 ++++++++++++++++----
 .../clustermeta/DefaultMetaSystemSuiteJ.java       | 24 ++++-----
 .../clustermeta/ha/MasterStateMachineSuiteJ.java   | 18 +++----
 .../ha/RatisMasterStatusSystemSuiteJ.java          | 24 ++++-----
 .../service/deploy/worker/FetchHandler.scala       | 12 +++--
 .../service/deploy/worker/storage/Flusher.scala    | 62 +++++-----------------
 .../deploy/worker/storage/StorageManager.scala     | 26 +++++----
 .../deploy/worker/storage/FileWriterSuiteJ.java    | 11 ++--
 .../deploy/worker/storage/DeviceMonitorSuite.scala |  6 ++-
 27 files changed, 383 insertions(+), 160 deletions(-)

diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
index a23ac2646..1109ac909 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
@@ -183,8 +183,7 @@ public class SortBasedPusher extends MemoryConsumer {
       throws IOException {
 
     if (getUsed() > pushSortMemoryThreshold
-        && pageCursor + size8k
-            > currentPage.getBaseOffset() + currentPage.size()) {
+        && pageCursor + size8k > currentPage.getBaseOffset() + currentPage.size()) {
       logger.info(
           "Memory Used across threshold, trigger push. Memory: "
               + getUsed()
diff --git a/common/src/main/java/org/apache/celeborn/common/network/server/ChunkStreamManager.java b/common/src/main/java/org/apache/celeborn/common/network/server/ChunkStreamManager.java
index 83541c855..6387aaca2 100644
--- a/common/src/main/java/org/apache/celeborn/common/network/server/ChunkStreamManager.java
+++ b/common/src/main/java/org/apache/celeborn/common/network/server/ChunkStreamManager.java
@@ -31,6 +31,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.common.meta.FileManagedBuffers;
+import org.apache.celeborn.common.meta.TimeWindow;
 import org.apache.celeborn.common.network.buffer.ManagedBuffer;
 
 /**
@@ -53,13 +54,15 @@ public class ChunkStreamManager {
     // Used to keep track of the index of the buffer that the user has retrieved, just to ensure
     // that the caller only requests each chunk one at a time, in order.
     int curChunk = 0;
+    final TimeWindow fetchTimeMetric;
 
     // Used to keep track of the number of chunks being transferred and not finished yet.
     volatile long chunksBeingTransferred = 0L;
 
-    StreamState(FileManagedBuffers buffers, Channel channel) {
+    StreamState(FileManagedBuffers buffers, Channel channel, TimeWindow fetchTimeMetric) {
       this.buffers = Preconditions.checkNotNull(buffers);
       this.associatedChannel = channel;
+      this.fetchTimeMetric = fetchTimeMetric;
     }
   }
 
@@ -99,6 +102,15 @@ public class ChunkStreamManager {
     return nextChunk;
   }
 
+  public TimeWindow getFetchTimeMetric(long streamId) {
+    StreamState state = streams.get(streamId);
+    if (state != null) {
+      return state.fetchTimeMetric;
+    } else {
+      return null;
+    }
+  }
+
   public static String genStreamChunkId(long streamId, int chunkId) {
     return String.format("%d_%d", streamId, chunkId);
   }
@@ -158,9 +170,10 @@ public class ChunkStreamManager {
    * to be the only reader of the stream. Once the connection is closed, the stream will never be
    * used again, enabling cleanup by `connectionTerminated`.
    */
-  public long registerStream(FileManagedBuffers buffers, Channel channel) {
+  public long registerStream(
+      FileManagedBuffers buffers, Channel channel, TimeWindow fetchTimeMetric) {
     long myStreamId = nextStreamId.getAndIncrement();
-    streams.put(myStreamId, new StreamState(buffers, channel));
+    streams.put(myStreamId, new StreamState(buffers, channel, fetchTimeMetric));
     return myStreamId;
   }
 
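
Each registered stream now carries an optional TimeWindow that the fetch path can later look up by stream id. A minimal sketch of that association, in the spirit of the updated ChunkStreamManagerSuiteJ further down (the no-argument ChunkStreamManager constructor and the Mockito mocks are assumptions taken from that suite, not from this hunk; the suite itself simply passes null for the metric):

    import io.netty.channel.Channel
    import org.mockito.Mockito
    import org.apache.celeborn.common.meta.{FileManagedBuffers, TimeWindow}
    import org.apache.celeborn.common.network.server.ChunkStreamManager

    object StreamFetchMetricSketch {
      def main(args: Array[String]): Unit = {
        val manager = new ChunkStreamManager()
        val buffers = Mockito.mock(classOf[FileManagedBuffers])
        val channel = Mockito.mock(classOf[Channel], Mockito.RETURNS_SMART_NULLS)
        val fetchTimeMetric = new TimeWindow(20, 100)

        // registerStream now takes the disk's fetch TimeWindow as a third argument.
        val streamId = manager.registerStream(buffers, channel, fetchTimeMetric)

        // The worker later looks the same window up by stream id while serving chunks.
        assert(manager.getFetchTimeMetric(streamId) eq fetchTimeMetric)
      }
    }
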
diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto
index 07cd0db84..0c1377dec 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -104,6 +104,7 @@ message PbDiskInfo {
   int64 avgFlushTime = 3;
   int64 usedSlots = 4;
   int32 status = 5;
+  int64 avgFetchTime = 6;
 }
 
 message PbWorkerInfo {
diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index f44651ec6..004428aa2 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -472,6 +472,10 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
   def slotsAssignLoadAwareDiskGroupNum: Int = get(SLOTS_ASSIGN_LOADAWARE_DISKGROUP_NUM)
   def slotsAssignLoadAwareDiskGroupGradient: Double =
     get(SLOTS_ASSIGN_LOADAWARE_DISKGROUP_GRADIENT)
+  def slotsAssignLoadAwareFlushTimeWeight: Double =
+    get(SLOTS_ASSIGN_LOADAWARE_FLUSHTIME_WEIGHT)
+  def slotsAssignLoadAwareFetchTimeWeight: Double =
+    get(SLOTS_ASSIGN_LOADAWARE_FETCHTIME_WEIGHT)
   def slotsAssignExtraSlots: Int = get(SLOTS_ASSIGN_EXTRA_SLOTS)
   def slotsAssignPolicy: SlotsAssignPolicy = SlotsAssignPolicy.valueOf(get(SLOTS_ASSIGN_POLICY))
   def initialEstimatedPartitionSize: Long = get(SHUFFLE_INITIAL_ESRIMATED_PARTITION_SIZE)
@@ -718,9 +722,11 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
   def hddFlusherThreads: Int = get(WORKER_FLUSHER_HDD_THREADS)
   def ssdFlusherThreads: Int = get(WORKER_FLUSHER_SSD_THREADS)
   def hdfsFlusherThreads: Int = get(WORKER_FLUSHER_HDFS_THREADS)
-  def avgFlushTimeSlidingWindowSize: Int = get(WORKER_FLUSHER_AVGFLUSHTIME_SLIDINGWINDOW_SIZE)
-  def avgFlushTimeSlidingWindowMinCount: Int =
-    get(WORKER_FLUSHER_AVGFLUSHTIME_SLIDINGWINDOW_MINCOUNT)
+  def diskTimeSlidingWindowSize: Int = get(WORKER_DISKTIME_SLIDINGWINDOW_SIZE)
+  def diskTimeSlidingWindowMinFlushCount: Int =
+    get(WORKER_DISKTIME_SLIDINGWINDOW_MINFLUSHCOUNT)
+  def diskTimeSlidingWindowMinFetchCount: Int =
+    get(WORKER_DISKTIME_SLIDINGWINDOW_MINFETCHCOUNT)
   def diskReserveSize: Long = get(WORKER_DISK_RESERVE_SIZE)
   def diskMonitorEnabled: Boolean = get(WORKER_DISK_MONITOR_ENABLED)
   def diskMonitorCheckList: Seq[String] = get(WORKER_DISK_MONITOR_CHECKLIST)
@@ -1948,25 +1954,37 @@ object CelebornConf extends Logging {
       .bytesConf(ByteUnit.BYTE)
       .createWithDefaultString("5G")
 
-  val WORKER_FLUSHER_AVGFLUSHTIME_SLIDINGWINDOW_SIZE: ConfigEntry[Int] =
-    buildConf("celeborn.worker.flusher.avgFlushTime.slidingWindow.size")
+  val WORKER_DISKTIME_SLIDINGWINDOW_SIZE: ConfigEntry[Int] =
+    buildConf("celeborn.worker.diskTime.slidingWindow.size")
+      .withAlternative("celeborn.worker.flusher.avgFlushTime.slidingWindow.size")
       .withAlternative("rss.flusher.avg.time.window")
       .categories("worker")
       .doc("The size of sliding windows used to calculate statistics about flushed time and count.")
-      .version("0.2.0")
+      .version("0.2.1")
       .intConf
       .createWithDefault(20)
 
-  val WORKER_FLUSHER_AVGFLUSHTIME_SLIDINGWINDOW_MINCOUNT: ConfigEntry[Int] =
-    buildConf("celeborn.worker.flusher.avgFlushTime.slidingWindow.minCount")
+  val WORKER_DISKTIME_SLIDINGWINDOW_MINFLUSHCOUNT: ConfigEntry[Int] =
+    buildConf("celeborn.worker.diskTime.slidingWindow.minFlushCount")
+      .withAlternative("celeborn.worker.flusher.avgFlushTime.slidingWindow.minCount")
       .withAlternative("rss.flusher.avg.time.minimum.count")
       .categories("worker")
       .doc("The minimum flush count to enter a sliding window" +
         " to calculate statistics about flushed time and count.")
-      .version("0.2.0")
+      .version("0.2.1")
       .internal
       .intConf
-      .createWithDefault(1000)
+      .createWithDefault(500)
+
+  val WORKER_DISKTIME_SLIDINGWINDOW_MINFETCHCOUNT: ConfigEntry[Int] =
+    buildConf("celeborn.worker.diskTime.slidingWindow.minFetchCount")
+      .categories("worker")
+      .doc("The minimum fetch count to enter a sliding window" +
+        " to calculate statistics about fetch time and count.")
+      .version("0.2.1")
+      .internal
+      .intConf
+      .createWithDefault(100)
 
   val SLOTS_ASSIGN_LOADAWARE_DISKGROUP_NUM: ConfigEntry[Int] =
     buildConf("celeborn.slots.assign.loadAware.numDiskGroups")
@@ -1988,6 +2006,24 @@ object CelebornConf extends Logging {
       .doubleConf
       .createWithDefault(0.1)
 
+  val SLOTS_ASSIGN_LOADAWARE_FLUSHTIME_WEIGHT: ConfigEntry[Double] =
+    buildConf("celeborn.slots.assign.loadAware.flushTimeWeight")
+      .categories("master")
+      .doc(
+        "Weight of average flush time when calculating ordering in load-aware assignment strategy")
+      .version("0.2.1")
+      .doubleConf
+      .createWithDefault(0)
+
+  val SLOTS_ASSIGN_LOADAWARE_FETCHTIME_WEIGHT: ConfigEntry[Double] =
+    buildConf("celeborn.slots.assign.loadAware.fetchTimeWeight")
+      .categories("master")
+      .doc(
+        "Weight of average fetch time when calculating ordering in load-aware assignment strategy")
+      .version("0.2.1")
+      .doubleConf
+      .createWithDefault(1)
+
   val SLOTS_ASSIGN_EXTRA_SLOTS: ConfigEntry[Int] =
     buildConf("celeborn.slots.assign.extraSlots")
       .withAlternative("rss.offer.slots.extra.size")
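
A quick sketch (not part of the patch) of how the new settings surface through the accessors added in this file: the weights default to 0.0 for flush time and 1.0 for fetch time, so out of the box the load-aware ordering is driven by fetch time alone, and the diskTime sliding-window settings size the per-disk TimeWindow instances introduced below.

    import org.apache.celeborn.common.CelebornConf

    object LoadAwareWeightsSketch {
      def main(args: Array[String]): Unit = {
        val conf = new CelebornConf(true) // loadDefaults, as in the primary constructor above
        // Accessors added in this file; the printed values are the new defaults.
        println(conf.slotsAssignLoadAwareFlushTimeWeight) // 0.0
        println(conf.slotsAssignLoadAwareFetchTimeWeight) // 1.0
        println(conf.diskTimeSlidingWindowSize)           // 20
        println(conf.diskTimeSlidingWindowMinFetchCount)  // 100
      }
    }
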
diff --git a/common/src/main/scala/org/apache/celeborn/common/meta/DeviceInfo.scala b/common/src/main/scala/org/apache/celeborn/common/meta/DeviceInfo.scala
index 13826f246..0561a6df4 100644
--- a/common/src/main/scala/org/apache/celeborn/common/meta/DeviceInfo.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/meta/DeviceInfo.scala
@@ -26,6 +26,7 @@ import scala.collection.mutable.ListBuffer
 
 import org.slf4j.LoggerFactory
 
+import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.protocol.StorageInfo
 import org.apache.celeborn.common.util.Utils.runCommand
@@ -35,18 +36,34 @@ class DiskInfo(
     var actualUsableSpace: Long,
     // avgFlushTime is nano seconds
     var avgFlushTime: Long,
+    var avgFetchTime: Long,
     var activeSlots: Long,
     val dirs: List[File],
     val deviceInfo: DeviceInfo) extends Serializable with Logging {
 
-  def this(mountPoint: String, usableSpace: Long, avgFlushTime: Long, activeSlots: Long) = {
-    this(mountPoint, usableSpace, avgFlushTime, activeSlots, List.empty, null)
+  def this(
+      mountPoint: String,
+      usableSpace: Long,
+      avgFlushTime: Long,
+      avgFetchTime: Long,
+      activeSlots: Long) = {
+    this(mountPoint, usableSpace, avgFlushTime, avgFetchTime, activeSlots, List.empty, null)
   }
 
-  def this(mountPoint: String, dirs: List[File], deviceInfo: DeviceInfo) = {
-    this(mountPoint, 0, 0, 0, dirs, deviceInfo)
+  def this(
+      mountPoint: String,
+      dirs: List[File],
+      deviceInfo: DeviceInfo,
+      conf: CelebornConf) = {
+    this(mountPoint, 0, 0, 0, 0, dirs, deviceInfo)
+    flushTimeMetrics =
+      new TimeWindow(conf.diskTimeSlidingWindowSize, conf.diskTimeSlidingWindowMinFlushCount)
+    fetchTimeMetrics =
+      new TimeWindow(conf.diskTimeSlidingWindowSize, conf.diskTimeSlidingWindowMinFetchCount)
   }
 
+  var flushTimeMetrics: TimeWindow = _
+  var fetchTimeMetrics: TimeWindow = _
   var status: DiskStatus = DiskStatus.HEALTHY
   var threadCount = 1
   var configuredUsableSpace = 0L
@@ -64,9 +81,12 @@ class DiskInfo(
     this
   }
 
-  def setFlushTime(avgFlushTime: Long): this.type = this.synchronized {
-    this.avgFlushTime = avgFlushTime
-    this
+  def updateFlushTime(): Unit = {
+    avgFlushTime = flushTimeMetrics.getAverage()
+  }
+
+  def updateFetchTime(): Unit = {
+    avgFetchTime = fetchTimeMetrics.getAverage()
   }
 
   def availableSlots(): Long = this.synchronized {
@@ -112,6 +132,7 @@ class DiskInfo(
       s" mountPoint: $mountPoint," +
       s" usableSpace: $actualUsableSpace," +
       s" avgFlushTime: $avgFlushTime," +
+      s" avgFetchTime: $avgFetchTime," +
       s" activeSlots: $activeSlots)" +
       s" status: $status" +
       s" dirs ${dirs.mkString("\t")}"
@@ -145,12 +166,13 @@ object DeviceInfo {
 
   /**
    * @param workingDirs array of (workingDir, max usable space, flush thread count, storage type)
-   * @return it will return three maps
+   * @return it will return two maps
    *         (deviceName -> deviceInfo)
    *         (mount point -> diskInfo)
    */
-  def getDeviceAndDiskInfos(workingDirs: Seq[(File, Long, Int, StorageInfo.Type)])
-      : (util.Map[String, DeviceInfo], util.Map[String, DiskInfo]) = {
+  def getDeviceAndDiskInfos(
+      workingDirs: Seq[(File, Long, Int, StorageInfo.Type)],
+      conf: CelebornConf): (util.Map[String, DeviceInfo], util.Map[String, DiskInfo]) = {
     val deviceNameToDeviceInfo = new util.HashMap[String, DeviceInfo]()
     val mountPointToDeviceInfo = new util.HashMap[String, DeviceInfo]()
 
@@ -209,7 +231,11 @@ object DeviceInfo {
     }.foreach {
       case (mountPoint, dirs) =>
         val deviceInfo = mountPointToDeviceInfo.get(mountPoint)
-        val diskInfo = new DiskInfo(mountPoint, dirs.map(_._1).toList, deviceInfo)
+        val diskInfo = new DiskInfo(
+          mountPoint,
+          dirs.map(_._1).toList,
+          deviceInfo,
+          conf)
         val (_, maxUsableSpace, threadCount, storageType) = dirs(0)
         diskInfo.configuredUsableSpace = maxUsableSpace
         diskInfo.threadCount = threadCount
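
DiskInfo now owns two TimeWindow instances, flushTimeMetrics and fetchTimeMetrics, and updateFlushTime/updateFetchTime fold their current averages into avgFlushTime/avgFetchTime, the values that reach the master. A minimal sketch using the five-argument auxiliary constructor added above, with a window attached by hand and hypothetical timings (the patch itself wires both windows through the dirs/deviceInfo/conf constructor):

    import org.apache.celeborn.common.meta.{DiskInfo, TimeWindow}

    object DiskFetchTimeSketch {
      def main(args: Array[String]): Unit = {
        // mountPoint, usableSpace, avgFlushTime, avgFetchTime, activeSlots
        val disk = new DiskInfo("/mnt/disk1", 64L * 1024 * 1024 * 1024, 0, 0, 0)
        disk.fetchTimeMetrics = new TimeWindow(20, 2)

        // Two hypothetical chunk fetches of 1 ms and 3 ms, recorded in nanoseconds.
        disk.fetchTimeMetrics.update(1000000L)
        disk.fetchTimeMetrics.update(3000000L)

        disk.updateFetchTime()
        println(disk.avgFetchTime) // 2000000, i.e. a 2 ms window average
      }
    }
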
diff --git a/common/src/main/scala/org/apache/celeborn/common/meta/TimeWindow.scala b/common/src/main/scala/org/apache/celeborn/common/meta/TimeWindow.scala
new file mode 100644
index 000000000..04fa887b3
--- /dev/null
+++ b/common/src/main/scala/org/apache/celeborn/common/meta/TimeWindow.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.meta
+
+import java.util.concurrent.atomic.LongAdder
+
+class TimeWindow(windowSize: Int, minWindowCount: Int) {
+  val totalCount = new LongAdder
+  val totalTime = new LongAdder
+  val timeWindow = new Array[(Long, Long)](windowSize)
+  var index = 0
+
+  for (i <- 0 until windowSize) {
+    timeWindow(i) = (0L, 0L)
+  }
+
+  def update(delta: Long): Unit = {
+    totalTime.add(delta)
+    totalCount.increment()
+  }
+
+  def getAverage(): Long = {
+    val currentTime = totalTime.sumThenReset()
+    val currentCount = totalCount.sumThenReset()
+
+    if (currentCount >= minWindowCount) {
+      timeWindow(index) = (currentTime, currentCount)
+      index = (index + 1) % windowSize
+    }
+
+    var time = 0L
+    var count = 0L
+    timeWindow.foreach { case (flushTime, flushCount) =>
+      time = time + flushTime
+      count = count + flushCount
+    }
+
+    if (count != 0) {
+      time / count
+    } else {
+      0L
+    }
+  }
+}
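
TimeWindow accumulates (time, count) deltas; each getAverage() call folds the accumulated pair into the fixed-size ring only when the count reaches minWindowCount, then returns total time divided by total count across the ring, so intervals with too few samples are skipped rather than diluting the average. A condensed sketch of those semantics (the same sequence is asserted by TimeWindowSuite further down):

    import org.apache.celeborn.common.meta.TimeWindow

    object TimeWindowSketch {
      def main(args: Array[String]): Unit = {
        val window = new TimeWindow(/* windowSize = */ 2, /* minWindowCount = */ 2)

        window.update(10)
        println(window.getAverage()) // only 1 sample (< minWindowCount), ring stays empty -> 0

        window.update(10)
        window.update(10)
        println(window.getAverage()) // ring now holds (20, 2) -> average 10

        window.update(5); window.update(5); window.update(5)
        println(window.getAverage()) // ring holds (20, 2) and (15, 3) -> 35 / 5 = 7
      }
    }
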
diff --git a/common/src/main/scala/org/apache/celeborn/common/meta/WorkerInfo.scala b/common/src/main/scala/org/apache/celeborn/common/meta/WorkerInfo.scala
index eabe9ae24..f1b605ba7 100644
--- a/common/src/main/scala/org/apache/celeborn/common/meta/WorkerInfo.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/meta/WorkerInfo.scala
@@ -200,6 +200,7 @@ class WorkerInfo(
         curDisk.actualUsableSpace_$eq(newDisk.actualUsableSpace)
         curDisk.activeSlots_$eq(Math.max(curDisk.activeSlots, newDisk.activeSlots))
         curDisk.avgFlushTime_$eq(newDisk.avgFlushTime)
+        curDisk.avgFetchTime_$eq(newDisk.avgFetchTime)
         curDisk.maxSlots_$eq(curDisk.actualUsableSpace / estimatedPartitionSize)
         curDisk.setStatus(newDisk.status)
       } else {
diff --git a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
index e35725ab1..56f74251d 100644
--- a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
@@ -68,6 +68,7 @@ object PbSerDeUtils {
       pbDiskInfo.getMountPoint,
       pbDiskInfo.getUsableSpace,
       pbDiskInfo.getAvgFlushTime,
+      pbDiskInfo.getAvgFetchTime,
       pbDiskInfo.getUsedSlots)
       .setStatus(Utils.toDiskStatus(pbDiskInfo.getStatus))
 
@@ -76,6 +77,7 @@ object PbSerDeUtils {
       .setMountPoint(diskInfo.mountPoint)
       .setUsableSpace(diskInfo.actualUsableSpace)
       .setAvgFlushTime(diskInfo.avgFlushTime)
+      .setAvgFetchTime(diskInfo.avgFetchTime)
       .setUsedSlots(diskInfo.activeSlots)
       .setStatus(diskInfo.status.getValue)
       .build
diff --git a/common/src/test/java/org/apache/celeborn/common/network/server/ChunkStreamManagerSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/network/server/ChunkStreamManagerSuiteJ.java
index 2fce016ee..cff1cc9f5 100644
--- a/common/src/test/java/org/apache/celeborn/common/network/server/ChunkStreamManagerSuiteJ.java
+++ b/common/src/test/java/org/apache/celeborn/common/network/server/ChunkStreamManagerSuiteJ.java
@@ -36,8 +36,8 @@ public class ChunkStreamManagerSuiteJ {
     FileManagedBuffers buffers2 = Mockito.mock(FileManagedBuffers.class);
 
     Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
-    manager.registerStream(buffers, dummyChannel);
-    manager.registerStream(buffers2, dummyChannel);
+    manager.registerStream(buffers, dummyChannel, null);
+    manager.registerStream(buffers2, dummyChannel, null);
 
     Assert.assertEquals(2, manager.numStreamStates());
 
diff --git a/common/src/test/scala/org/apache/celeborn/common/meta/TimeWindowSuite.scala b/common/src/test/scala/org/apache/celeborn/common/meta/TimeWindowSuite.scala
new file mode 100644
index 000000000..fdc41fa08
--- /dev/null
+++ b/common/src/test/scala/org/apache/celeborn/common/meta/TimeWindowSuite.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.meta
+
+import org.apache.celeborn.RssFunSuite
+
+class TimeWindowSuite extends RssFunSuite {
+
+  test("test TimeWindow") {
+    val tw = new TimeWindow(2, 2)
+    tw.update(10)
+    tw.getAverage()
+    assert(tw.index == 0)
+    tw.update(10)
+    tw.update(10)
+    tw.getAverage()
+    assert(tw.index == 1)
+    assert(tw.timeWindow(0) == (20, 2))
+    tw.update(5)
+    tw.update(5)
+    tw.update(5)
+    val avg = tw.getAverage()
+    assert(tw.timeWindow(1) == (15, 3))
+    assert(tw.index == 0)
+    assert(avg == 7)
+  }
+}
diff --git a/common/src/test/scala/org/apache/celeborn/common/meta/WorkerInfoSuite.scala b/common/src/test/scala/org/apache/celeborn/common/meta/WorkerInfoSuite.scala
index 87c2af903..74d814bef 100644
--- a/common/src/test/scala/org/apache/celeborn/common/meta/WorkerInfoSuite.scala
+++ b/common/src/test/scala/org/apache/celeborn/common/meta/WorkerInfoSuite.scala
@@ -64,9 +64,9 @@ class WorkerInfoSuite extends RssFunSuite {
   test("multi-thread modify same WorkerInfo.") {
     val numSlots = 10000
     val disks = new util.HashMap[String, DiskInfo]()
-    disks.put("disk1", new DiskInfo("disk1", Int.MaxValue, 1, 0))
-    disks.put("disk2", new DiskInfo("disk2", Int.MaxValue, 1, 0))
-    disks.put("disk3", new DiskInfo("disk3", Int.MaxValue, 1, 0))
+    disks.put("disk1", new DiskInfo("disk1", Int.MaxValue, 1, 1, 0))
+    disks.put("disk2", new DiskInfo("disk2", Int.MaxValue, 1, 1, 0))
+    disks.put("disk3", new DiskInfo("disk3", Int.MaxValue, 1, 1, 0))
     val userResourceConsumption = new ConcurrentHashMap[UserIdentifier, ResourceConsumption]()
     userResourceConsumption.put(UserIdentifier("tenant1", "name1"), ResourceConsumption(1, 1, 1, 1))
     val worker =
@@ -224,9 +224,9 @@ class WorkerInfoSuite extends RssFunSuite {
       null)
 
     val disks = new util.HashMap[String, DiskInfo]()
-    disks.put("disk1", new DiskInfo("disk1", Int.MaxValue, 1, 10))
-    disks.put("disk2", new DiskInfo("disk2", Int.MaxValue, 2, 20))
-    disks.put("disk3", new DiskInfo("disk3", Int.MaxValue, 3, 30))
+    disks.put("disk1", new DiskInfo("disk1", Int.MaxValue, 1, 1, 10))
+    disks.put("disk2", new DiskInfo("disk2", Int.MaxValue, 2, 2, 20))
+    disks.put("disk3", new DiskInfo("disk3", Int.MaxValue, 3, 3, 30))
     val userResourceConsumption = new ConcurrentHashMap[UserIdentifier, ResourceConsumption]()
     userResourceConsumption.put(
       UserIdentifier("tenant1", "name1"),
@@ -297,9 +297,9 @@ class WorkerInfoSuite extends RssFunSuite {
          |SlotsUsed: 60
          |LastHeartbeat: 0
          |Disks: $placeholder
-         |  DiskInfo0: DiskInfo(maxSlots: 0, committed shuffles 0 shuffleAllocations: Map(), mountPoint: disk3, usableSpace: 2147483647, avgFlushTime: 3, activeSlots: 30) status: HEALTHY dirs $placeholder
-         |  DiskInfo1: DiskInfo(maxSlots: 0, committed shuffles 0 shuffleAllocations: Map(), mountPoint: disk1, usableSpace: 2147483647, avgFlushTime: 1, activeSlots: 10) status: HEALTHY dirs $placeholder
-         |  DiskInfo2: DiskInfo(maxSlots: 0, committed shuffles 0 shuffleAllocations: Map(), mountPoint: disk2, usableSpace: 2147483647, avgFlushTime: 2, activeSlots: 20) status: HEALTHY dirs $placeholder
+         |  DiskInfo0: DiskInfo(maxSlots: 0, committed shuffles 0 shuffleAllocations: Map(), mountPoint: disk3, usableSpace: 2147483647, avgFlushTime: 3, avgFetchTime: 3, activeSlots: 30) status: HEALTHY dirs $placeholder
+         |  DiskInfo1: DiskInfo(maxSlots: 0, committed shuffles 0 shuffleAllocations: Map(), mountPoint: disk1, usableSpace: 2147483647, avgFlushTime: 1, avgFetchTime: 1, activeSlots: 10) status: HEALTHY dirs $placeholder
+         |  DiskInfo2: DiskInfo(maxSlots: 0, committed shuffles 0 shuffleAllocations: Map(), mountPoint: disk2, usableSpace: 2147483647, avgFlushTime: 2, avgFetchTime: 2, activeSlots: 20) status: HEALTHY dirs $placeholder
          |UserResourceConsumption: $placeholder
          |  UserIdentifier: `tenant1`.`name1`, ResourceConsumption: ResourceConsumption(diskBytesWritten: 20.0 MB, diskFileCount: 1, hdfsBytesWritten: 50.0 MB, hdfsFileCount: 1)
          |WorkerRef: NettyRpcEndpointRef(rss://mockRpc@localhost:12345)
diff --git a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
index 7eb1eb2a5..9acc7ccb8 100644
--- a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
+++ b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
@@ -43,8 +43,8 @@ class PbSerDeUtilsTest extends RssFunSuite {
   val files = List(file1, file2)
 
   val device = new DeviceInfo("device-a")
-  val diskInfo1 = new DiskInfo("/mnt/disk/0", 1000, 1000, 1000, files, device)
-  val diskInfo2 = new DiskInfo("/mnt/disk/1", 2000, 2000, 2000, files, device)
+  val diskInfo1 = new DiskInfo("/mnt/disk/0", 1000, 1000, 1000, 1000, files, device)
+  val diskInfo2 = new DiskInfo("/mnt/disk/1", 2000, 2000, 2000, 2000, files, device)
   val diskInfos = new util.HashMap[String, DiskInfo]()
   diskInfos.put("disk1", diskInfo1)
   diskInfos.put("disk2", diskInfo2)
@@ -108,6 +108,7 @@ class PbSerDeUtilsTest extends RssFunSuite {
     assert(restoredDiskInfo.mountPoint.equals(diskInfo1.mountPoint))
     assert(restoredDiskInfo.actualUsableSpace.equals(diskInfo1.actualUsableSpace))
     assert(restoredDiskInfo.avgFlushTime.equals(diskInfo1.avgFlushTime))
+    assert(restoredDiskInfo.avgFetchTime.equals(diskInfo1.avgFetchTime))
     assert(restoredDiskInfo.activeSlots.equals(diskInfo1.activeSlots))
     assert(restoredDiskInfo.dirs.equals(List.empty))
     assert(restoredDiskInfo.deviceInfo == null)
diff --git a/docs/configuration/master.md b/docs/configuration/master.md
index 59adaa0c7..4ed10dfe4 100644
--- a/docs/configuration/master.md
+++ b/docs/configuration/master.md
@@ -43,6 +43,8 @@ license: |
 | celeborn.shuffle.initialEstimatedPartitionSize | 64mb | Initial partition size for estimation, it will change according to runtime stats. | 0.2.0 | 
 | celeborn.slots.assign.extraSlots | 2 | Extra slots number when master assign slots. | 0.2.0 | 
 | celeborn.slots.assign.loadAware.diskGroupGradient | 0.1 | This value means how many more workload will be placed into a faster disk group than a slower group. | 0.2.0 | 
+| celeborn.slots.assign.loadAware.fetchTimeWeight | 1.0 | Weight of average fetch time when calculating ordering in load-aware assignment strategy | 0.2.1 | 
+| celeborn.slots.assign.loadAware.flushTimeWeight | 0.0 | Weight of average flush time when calculating ordering in load-aware assignment strategy | 0.2.1 | 
 | celeborn.slots.assign.loadAware.numDiskGroups | 5 | This configuration is a guidance for load-aware slot allocation algorithm. This value is control how many disk groups will be created. | 0.2.0 | 
 | celeborn.slots.assign.policy | ROUNDROBIN | Policy for master to assign slots, Celeborn supports two types of policy: roundrobin and loadaware. | 0.2.0 | 
 | celeborn.worker.heartbeat.timeout | 120s | Worker heartbeat timeout. | 0.2.0 | 
diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md
index 8fe879780..0a93f5ab1 100644
--- a/docs/configuration/worker.md
+++ b/docs/configuration/worker.md
@@ -42,9 +42,9 @@ license: |
 | celeborn.worker.disk.checkFileClean.maxRetries | 3 | The number of retries for a worker to check if the working directory is cleaned up before registering with the master. | 0.2.0 | 
 | celeborn.worker.disk.checkFileClean.timeout | 1000ms | The wait time per retry for a worker to check if the working directory is cleaned up before registering with the master. | 0.2.0 | 
 | celeborn.worker.disk.reserve.size | 5G | Celeborn worker reserved space for each disk. | 0.2.0 | 
+| celeborn.worker.diskTime.slidingWindow.size | 20 | The size of sliding windows used to calculate statistics about flushed time and count. | 0.2.1 | 
 | celeborn.worker.fetch.io.threads | 16 | Netty IO thread number of worker to handle client fetch data. The default threads number is 16. | 0.2.0 | 
 | celeborn.worker.fetch.port | 0 | Server port for Worker to receive fetch data request from ShuffleClient. | 0.2.0 | 
-| celeborn.worker.flusher.avgFlushTime.slidingWindow.size | 20 | The size of sliding windows used to calculate statistics about flushed time and count. | 0.2.0 | 
 | celeborn.worker.flusher.buffer.size | 256k | Size of buffer used by a single flusher. | 0.2.0 | 
 | celeborn.worker.flusher.hdd.threads | 1 | Flusher's thread count per disk used for write data to HDD disks. | 0.2.0 | 
 | celeborn.worker.flusher.hdfs.threads | 4 | Flusher's thread count used for write data to HDFS. | 0.2.0 | 
diff --git a/master/src/main/java/org/apache/celeborn/service/deploy/master/SlotsAllocator.java b/master/src/main/java/org/apache/celeborn/service/deploy/master/SlotsAllocator.java
index c41f34a9d..a22ef7656 100644
--- a/master/src/main/java/org/apache/celeborn/service/deploy/master/SlotsAllocator.java
+++ b/master/src/main/java/org/apache/celeborn/service/deploy/master/SlotsAllocator.java
@@ -90,7 +90,9 @@ public class SlotsAllocator {
           boolean shouldReplicate,
           long minimumUsableSize,
           int diskGroupCount,
-          double diskGroupGradient) {
+          double diskGroupGradient,
+          double flushTimeWeight,
+          double fetchTimeWeight) {
     if (partitionIds.isEmpty()) {
       return new HashMap<>();
     }
@@ -130,7 +132,7 @@ public class SlotsAllocator {
 
     Map<WorkerInfo, List<UsableDiskInfo>> restriction =
         getRestriction(
-            placeDisksToGroups(usableDisks, diskGroupCount),
+            placeDisksToGroups(usableDisks, diskGroupCount, flushTimeWeight, fetchTimeWeight),
             diskToWorkerMap,
             shouldReplicate ? partitionIds.size() * 2 : partitionIds.size());
 
@@ -254,9 +256,18 @@ public class SlotsAllocator {
   }
 
   private static List<List<DiskInfo>> placeDisksToGroups(
-      List<DiskInfo> usableDisks, int diskGroupCount) {
+      List<DiskInfo> usableDisks,
+      int diskGroupCount,
+      double flushTimeWeight,
+      double fetchTimeWeight) {
     List<List<DiskInfo>> diskGroups = new ArrayList<>();
-    usableDisks.sort((o1, o2) -> Math.toIntExact(o1.avgFlushTime() - o2.avgFlushTime()));
+    usableDisks.sort(
+        (o1, o2) ->
+            Math.toIntExact(
+                (long)
+                    ((o1.avgFlushTime() * flushTimeWeight + o1.avgFetchTime() * fetchTimeWeight)
+                        - (o2.avgFlushTime() * flushTimeWeight
+                            + o2.avgFetchTime() * fetchTimeWeight))));
     int diskCount = usableDisks.size();
     int startIndex = 0;
     int groupSizeSize = (int) Math.ceil(usableDisks.size() / (double) diskGroupCount);
@@ -363,6 +374,8 @@ public class SlotsAllocator {
               .append(diskInfo.mountPoint())
               .append(" flushtime:")
               .append(diskInfo.avgFlushTime())
+              .append(" fetchtime:")
+              .append(diskInfo.avgFetchTime())
               .append(" allocation: ")
               .append(allocation)
               .append(" ");
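
The comparator above now orders disks by avgFlushTime * flushTimeWeight + avgFetchTime * fetchTimeWeight before splitting them into diskGroupCount groups of ceil(n / diskGroupCount) disks each. A simplified Scala rendering of that ordering-and-grouping step, using plain tuples and hypothetical timings instead of DiskInfo, with the patch's default weights:

    object DiskGroupingSketch {
      def main(args: Array[String]): Unit = {
        val (flushTimeWeight, fetchTimeWeight) = (0.0, 1.0) // patch defaults
        // (mountPoint, avgFlushTime, avgFetchTime) in nanoseconds -- hypothetical values.
        val disks = Seq(
          ("/mnt/disk1", 400L, 900L),
          ("/mnt/disk2", 900L, 100L),
          ("/mnt/disk3", 500L, 500L))

        val sorted = disks.sortBy { case (_, flush, fetch) =>
          flush * flushTimeWeight + fetch * fetchTimeWeight
        }
        val diskGroupCount = 2
        val groupSize = math.ceil(sorted.size.toDouble / diskGroupCount).toInt
        sorted.grouped(groupSize).zipWithIndex.foreach { case (group, i) =>
          println(s"group $i: ${group.map(_._1).mkString(", ")}")
        }
        // group 0: /mnt/disk2, /mnt/disk3   (fastest disks; they receive more workload)
        // group 1: /mnt/disk1
      }
    }
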
diff --git a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/MetaUtil.java b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/MetaUtil.java
index d8a5d811c..ec3094fc2 100644
--- a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/MetaUtil.java
+++ b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/MetaUtil.java
@@ -57,7 +57,11 @@ public class MetaUtil {
         (k, v) -> {
           DiskInfo diskInfo =
               new DiskInfo(
-                  v.getMountPoint(), v.getUsableSpace(), v.getAvgFlushTime(), v.getUsedSlots());
+                  v.getMountPoint(),
+                  v.getUsableSpace(),
+                  v.getAvgFlushTime(),
+                  v.getAvgFetchTime(),
+                  v.getUsedSlots());
           diskInfo.setStatus(Utils.toDiskStatus(v.getStatus()));
           map.put(k, diskInfo);
         });
@@ -75,6 +79,7 @@ public class MetaUtil {
                     .setMountPoint(v.mountPoint())
                     .setUsableSpace(v.actualUsableSpace())
                     .setAvgFlushTime(v.avgFlushTime())
+                    .setAvgFetchTime(v.avgFetchTime())
                     .setUsedSlots(v.activeSlots())
                     .setStatus(v.status().getValue())
                     .build()));
diff --git a/master/src/main/proto/Resource.proto b/master/src/main/proto/Resource.proto
index d07332188..b639fc98d 100644
--- a/master/src/main/proto/Resource.proto
+++ b/master/src/main/proto/Resource.proto
@@ -57,6 +57,7 @@ message DiskInfo {
   required int64 avgFlushTime = 3;
   required int64 usedSlots = 4;
   required int32 status = 5;
+  required int64 avgFetchTime = 6;
 }
 
 message RequestSlotsRequest {
diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
index 4be15bcf3..c027aca83 100644
--- a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
+++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
@@ -104,9 +104,10 @@ private[celeborn] class Master(
 
   private def diskReserveSize = conf.diskReserveSize
 
-  private def slotsAssignLoadAwareDiskGroupNum = conf.slotsAssignLoadAwareDiskGroupNum
-
-  private def slotsAssignLoadAwareDiskGroupGradient = conf.slotsAssignLoadAwareDiskGroupGradient
+  private val slotsAssignLoadAwareDiskGroupNum = conf.slotsAssignLoadAwareDiskGroupNum
+  private val slotsAssignLoadAwareDiskGroupGradient = conf.slotsAssignLoadAwareDiskGroupGradient
+  private val loadAwareFlushTimeWeight = conf.slotsAssignLoadAwareFlushTimeWeight
+  private val loadAwareFetchTimeWeight = conf.slotsAssignLoadAwareFetchTimeWeight
 
   private val estimatedPartitionSizeUpdaterInitialDelay =
     conf.estimatedPartitionSizeUpdaterInitialDelay
@@ -510,7 +511,9 @@ private[celeborn] class Master(
               requestSlots.shouldReplicate,
               diskReserveSize,
               slotsAssignLoadAwareDiskGroupNum,
-              slotsAssignLoadAwareDiskGroupGradient)
+              slotsAssignLoadAwareDiskGroupGradient,
+              loadAwareFlushTimeWeight,
+              loadAwareFetchTimeWeight)
           }
         }
       }
diff --git a/master/src/test/java/org/apache/celeborn/service/deploy/master/SlotsAllocatorSuiteJ.java b/master/src/test/java/org/apache/celeborn/service/deploy/master/SlotsAllocatorSuiteJ.java
index f60e60abb..9617dcf43 100644
--- a/master/src/test/java/org/apache/celeborn/service/deploy/master/SlotsAllocatorSuiteJ.java
+++ b/master/src/test/java/org/apache/celeborn/service/deploy/master/SlotsAllocatorSuiteJ.java
@@ -45,13 +45,25 @@ public class SlotsAllocatorSuiteJ {
     Map<String, DiskInfo> disks1 = new HashMap<>();
     DiskInfo diskInfo1 =
         new DiskInfo(
-            "/mnt/disk1", random.nextInt() + 100 * 1024 * 1024 * 1024L, random.nextInt(1000), 0);
+            "/mnt/disk1",
+            random.nextInt() + 100 * 1024 * 1024 * 1024L,
+            random.nextInt(1000),
+            random.nextInt(1000),
+            0);
     DiskInfo diskInfo2 =
         new DiskInfo(
-            "/mnt/disk2", random.nextInt() + 95 * 1024 * 1024 * 1024L, random.nextInt(1000), 0);
+            "/mnt/disk2",
+            random.nextInt() + 95 * 1024 * 1024 * 1024L,
+            random.nextInt(1000),
+            random.nextInt(1000),
+            0);
     DiskInfo diskInfo3 =
         new DiskInfo(
-            "/mnt/disk3", random.nextInt() + 90 * 1024 * 1024 * 1024L, random.nextInt(1000), 0);
+            "/mnt/disk3",
+            random.nextInt() + 90 * 1024 * 1024 * 1024L,
+            random.nextInt(1000),
+            random.nextInt(1000),
+            0);
     diskInfo1.maxSlots_$eq(diskInfo1.actualUsableSpace() / assumedPartitionSize);
     diskInfo2.maxSlots_$eq(diskInfo2.actualUsableSpace() / assumedPartitionSize);
     diskInfo3.maxSlots_$eq(diskInfo3.actualUsableSpace() / assumedPartitionSize);
@@ -62,13 +74,25 @@ public class SlotsAllocatorSuiteJ {
     Map<String, DiskInfo> disks2 = new HashMap<>();
     DiskInfo diskInfo4 =
         new DiskInfo(
-            "/mnt/disk1", random.nextInt() + 100 * 1024 * 1024 * 1024L, random.nextInt(1000), 0);
+            "/mnt/disk1",
+            random.nextInt() + 100 * 1024 * 1024 * 1024L,
+            random.nextInt(1000),
+            random.nextInt(1000),
+            0);
     DiskInfo diskInfo5 =
         new DiskInfo(
-            "/mnt/disk2", random.nextInt() + 95 * 1024 * 1024 * 1024L, random.nextInt(1000), 0);
+            "/mnt/disk2",
+            random.nextInt() + 95 * 1024 * 1024 * 1024L,
+            random.nextInt(1000),
+            random.nextInt(1000),
+            0);
     DiskInfo diskInfo6 =
         new DiskInfo(
-            "/mnt/disk3", random.nextInt() + 90 * 1024 * 1024 * 1024L, random.nextInt(1000), 0);
+            "/mnt/disk3",
+            random.nextInt() + 90 * 1024 * 1024 * 1024L,
+            random.nextInt(1000),
+            random.nextInt(1000),
+            0);
     diskInfo4.maxSlots_$eq(diskInfo4.actualUsableSpace() / assumedPartitionSize);
     diskInfo5.maxSlots_$eq(diskInfo5.actualUsableSpace() / assumedPartitionSize);
     diskInfo6.maxSlots_$eq(diskInfo6.actualUsableSpace() / assumedPartitionSize);
@@ -79,13 +103,25 @@ public class SlotsAllocatorSuiteJ {
     Map<String, DiskInfo> disks3 = new HashMap<>();
     DiskInfo diskInfo7 =
         new DiskInfo(
-            "/mnt/disk1", random.nextInt() + 100 * 1024 * 1024 * 1024L, random.nextInt(1000), 0);
+            "/mnt/disk1",
+            random.nextInt() + 100 * 1024 * 1024 * 1024L,
+            random.nextInt(1000),
+            random.nextInt(1000),
+            0);
     DiskInfo diskInfo8 =
         new DiskInfo(
-            "/mnt/disk2", random.nextInt() + 95 * 1024 * 1024 * 1024L, random.nextInt(1000), 0);
+            "/mnt/disk2",
+            random.nextInt() + 95 * 1024 * 1024 * 1024L,
+            random.nextInt(1000),
+            random.nextInt(1000),
+            0);
     DiskInfo diskInfo9 =
         new DiskInfo(
-            "/mnt/disk3", random.nextInt() + 90 * 1024 * 1024 * 1024L, random.nextInt(1000), 0);
+            "/mnt/disk3",
+            random.nextInt() + 90 * 1024 * 1024 * 1024L,
+            random.nextInt(1000),
+            random.nextInt(1000),
+            0);
     diskInfo7.maxSlots_$eq(diskInfo7.actualUsableSpace() / assumedPartitionSize);
     diskInfo8.maxSlots_$eq(diskInfo8.actualUsableSpace() / assumedPartitionSize);
     diskInfo9.maxSlots_$eq(diskInfo9.actualUsableSpace() / assumedPartitionSize);
@@ -194,7 +230,9 @@ public class SlotsAllocatorSuiteJ {
             shouldReplicate,
             10 * 1024 * 1024 * 1024L,
             conf.slotsAssignLoadAwareDiskGroupNum(),
-            conf.slotsAssignLoadAwareDiskGroupGradient());
+            conf.slotsAssignLoadAwareDiskGroupGradient(),
+            conf.slotsAssignLoadAwareFlushTimeWeight(),
+            conf.slotsAssignLoadAwareFetchTimeWeight());
     if (expectSuccess) {
       if (shouldReplicate) {
         slots.forEach(
diff --git a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/DefaultMetaSystemSuiteJ.java b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/DefaultMetaSystemSuiteJ.java
index 6a4b81936..2913d50ad 100644
--- a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/DefaultMetaSystemSuiteJ.java
+++ b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/DefaultMetaSystemSuiteJ.java
@@ -83,22 +83,22 @@ public class DefaultMetaSystemSuiteJ {
     statusSystem = new SingleMasterMetaManager(mockRpcEnv, conf);
 
     disks1.clear();
-    disks1.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks1.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks1.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks1.put("disk4", new DiskInfo("disk4", 64 * 1024 * 1024 * 1024L, 100, 0));
+    disks1.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks1.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks1.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks1.put("disk4", new DiskInfo("disk4", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
 
     disks2.clear();
-    disks2.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks2.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks2.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks2.put("disk4", new DiskInfo("disk4", 64 * 1024 * 1024 * 1024L, 100, 0));
+    disks2.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks2.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks2.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks2.put("disk4", new DiskInfo("disk4", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
 
     disks3.clear();
-    disks3.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks3.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks3.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks3.put("disk4", new DiskInfo("disk4", 64 * 1024 * 1024 * 1024L, 100, 0));
+    disks3.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks3.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks3.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks3.put("disk4", new DiskInfo("disk4", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
   }
 
   @After
diff --git a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterStateMachineSuiteJ.java b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterStateMachineSuiteJ.java
index afc4f009c..7c090c55e 100644
--- a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterStateMachineSuiteJ.java
+++ b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterStateMachineSuiteJ.java
@@ -207,9 +207,9 @@ public class MasterStateMachineSuiteJ extends RatisBaseSuiteJ {
     File tmpFile = File.createTempFile("tef", "test" + System.currentTimeMillis());
 
     Map<String, DiskInfo> disks1 = new HashMap<>();
-    disks1.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024, 100, 0));
-    disks1.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024, 100, 0));
-    disks1.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024, 100, 0));
+    disks1.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024, 100, 100, 0));
+    disks1.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024, 100, 100, 0));
+    disks1.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024, 100, 100, 0));
     Map<UserIdentifier, ResourceConsumption> userResourceConsumption1 = new ConcurrentHashMap<>();
     userResourceConsumption1.put(
         new UserIdentifier("tenant1", "name1"), new ResourceConsumption(1000, 1, 1000, 1));
@@ -219,9 +219,9 @@ public class MasterStateMachineSuiteJ extends RatisBaseSuiteJ {
         new UserIdentifier("tenant1", "name3"), new ResourceConsumption(3000, 3, 3000, 3));
 
     Map<String, DiskInfo> disks2 = new HashMap<>();
-    disks2.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024, 100, 0));
-    disks2.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024, 100, 0));
-    disks2.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024, 100, 0));
+    disks2.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024, 100, 100, 0));
+    disks2.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024, 100, 100, 0));
+    disks2.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024, 100, 100, 0));
     Map<UserIdentifier, ResourceConsumption> userResourceConsumption2 = new ConcurrentHashMap<>();
     userResourceConsumption2.put(
         new UserIdentifier("tenant2", "name1"), new ResourceConsumption(1000, 1, 1000, 1));
@@ -231,9 +231,9 @@ public class MasterStateMachineSuiteJ extends RatisBaseSuiteJ {
         new UserIdentifier("tenant2", "name3"), new ResourceConsumption(3000, 3, 3000, 3));
 
     Map<String, DiskInfo> disks3 = new HashMap<>();
-    disks3.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024, 100, 0));
-    disks3.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024, 100, 0));
-    disks3.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024, 100, 0));
+    disks3.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024, 100, 100, 0));
+    disks3.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024, 100, 100, 0));
+    disks3.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024, 100, 100, 0));
     Map<UserIdentifier, ResourceConsumption> userResourceConsumption3 = new ConcurrentHashMap<>();
     userResourceConsumption3.put(
         new UserIdentifier("tenant3", "name1"), new ResourceConsumption(1000, 1, 1000, 1));
diff --git a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java
index 7414cdd71..34bdd8b58 100644
--- a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java
+++ b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java
@@ -765,22 +765,22 @@ public class RatisMasterStatusSystemSuiteJ {
     STATUSSYSTEM3.workerLostEvents.clear();
 
     disks1.clear();
-    disks1.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks1.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks1.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks1.put("disk4", new DiskInfo("disk4", 64 * 1024 * 1024 * 1024L, 100, 0));
+    disks1.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks1.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks1.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks1.put("disk4", new DiskInfo("disk4", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
 
     disks2.clear();
-    disks2.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks2.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks2.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks2.put("disk4", new DiskInfo("disk4", 64 * 1024 * 1024 * 1024L, 100, 0));
+    disks2.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks2.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks2.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks2.put("disk4", new DiskInfo("disk4", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
 
     disks3.clear();
-    disks3.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks3.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks3.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024L, 100, 0));
-    disks3.put("disk4", new DiskInfo("disk4", 64 * 1024 * 1024 * 1024L, 100, 0));
+    disks3.put("disk1", new DiskInfo("disk1", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks3.put("disk2", new DiskInfo("disk2", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks3.put("disk3", new DiskInfo("disk3", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
+    disks3.put("disk4", new DiskInfo("disk4", 64 * 1024 * 1024 * 1024L, 100, 100, 0));
   }
 
   @Test
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index 62b846deb..7a3e7ed55 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -107,7 +107,9 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
               new NioManagedBuffer(streamHandle.toByteBuffer)))
           } else {
             val buffers = new FileManagedBuffers(fileInfo, conf)
-            val streamId = chunkStreamManager.registerStream(buffers, client.getChannel)
+            val fetchTimeMetrics = storageManager.getFetchTimeMetric(fileInfo.getFile)
+            val streamId =
+              chunkStreamManager.registerStream(buffers, client.getChannel, fetchTimeMetrics)
             val streamHandle = new StreamHandle(streamId, fileInfo.numChunks())
             if (fileInfo.numChunks() == 0)
               logDebug(s"StreamId $streamId fileName $fileName startMapIndex" +
@@ -142,7 +144,6 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
   }
 
   def handleChunkFetchRequest(client: TransportClient, req: ChunkFetchRequest): Unit = {
-    workerSource.startTimer(WorkerSource.FetchChunkTime, req.toString)
     logTrace(s"Received req from ${NettyUtils.getRemoteAddress(client.getChannel)}" +
       s" to fetch block ${req.streamChunkSlice}")
 
@@ -154,8 +155,10 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
       logError(message)
       client.getChannel.writeAndFlush(
         new ChunkFetchFailure(req.streamChunkSlice, message))
-      workerSource.stopTimer(WorkerSource.FetchChunkTime, req.toString)
     } else {
+      workerSource.startTimer(WorkerSource.FetchChunkTime, req.toString)
+      val fetchTimeMetric = chunkStreamManager.getFetchTimeMetric(req.streamChunkSlice.streamId)
+      val fetchBeginTime = System.nanoTime()
       try {
         val buf = chunkStreamManager.getChunk(
           req.streamChunkSlice.streamId,
@@ -167,6 +170,9 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
           .addListener(new GenericFutureListener[Future[_ >: Void]] {
             override def operationComplete(future: Future[_ >: Void]): Unit = {
               chunkStreamManager.chunkSent(req.streamChunkSlice.streamId)
+              if (fetchTimeMetric != null) {
+                fetchTimeMetric.update(System.nanoTime() - fetchBeginTime)
+              }
               workerSource.stopTimer(WorkerSource.FetchChunkTime, req.toString)
             }
           })
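
handleChunkFetchRequest now records System.nanoTime() before serving a chunk and, once the channel write completes, feeds the elapsed nanoseconds into the stream's TimeWindow if one was registered. The same pattern with the Netty plumbing stripped away (sendChunk is a hypothetical stand-in for getChunk plus the writeAndFlush call):

    import org.apache.celeborn.common.meta.TimeWindow

    object FetchTimingSketch {
      // Hypothetical stand-in for getChunk + the channel write in the handler.
      private def sendChunk(): Unit = Thread.sleep(1)

      def main(args: Array[String]): Unit = {
        val fetchTimeMetric: TimeWindow = new TimeWindow(20, 100) // may be null for some streams
        val fetchBeginTime = System.nanoTime()
        sendChunk()
        // In the patch this update runs inside the GenericFutureListener once the write completes.
        if (fetchTimeMetric != null) {
          fetchTimeMetric.update(System.nanoTime() - fetchBeginTime)
        }
      }
    }
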
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/Flusher.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/Flusher.scala
index 1823733c0..2e7fbce87 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/Flusher.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/Flusher.scala
@@ -20,7 +20,7 @@ package org.apache.celeborn.service.deploy.worker.storage
 import java.io.IOException
 import java.nio.channels.ClosedByInterruptException
 import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
-import java.util.concurrent.atomic.{AtomicBoolean, AtomicLongArray, LongAdder}
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicLongArray}
 
 import scala.collection.JavaConverters._
 import scala.util.Random
@@ -28,7 +28,7 @@ import scala.util.Random
 import io.netty.buffer.{CompositeByteBuf, Unpooled}
 
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.meta.DiskStatus
+import org.apache.celeborn.common.meta.{DiskStatus, TimeWindow}
 import org.apache.celeborn.common.metrics.source.AbstractSource
 import org.apache.celeborn.common.network.server.memory.MemoryManager
 import org.apache.celeborn.common.protocol.StorageInfo
@@ -37,17 +37,12 @@ import org.apache.celeborn.service.deploy.worker.WorkerSource
 abstract private[worker] class Flusher(
     val workerSource: AbstractSource,
     val threadCount: Int,
-    val avgFlushTimeSlidingWindowSize: Int,
-    val avgFlushTimeSlidingWindowMinCount: Int) extends Logging {
+    flushTimeMetric: TimeWindow) extends Logging {
   protected lazy val flusherId = System.identityHashCode(this)
   protected val workingQueues = new Array[LinkedBlockingQueue[FlushTask]](threadCount)
   protected val bufferQueue = new LinkedBlockingQueue[CompositeByteBuf]()
   protected val workers = new Array[Thread](threadCount)
   protected var nextWorkerIndex: Int = 0
-  protected val flushCount = new LongAdder
-  protected val flushTotalTime = new LongAdder
-  protected val avgTimeWindow = new Array[(Long, Long)](avgFlushTimeSlidingWindowSize)
-  protected var avgTimeWindowCurrentIndex = 0
 
   val lastBeginFlushTime: AtomicLongArray = new AtomicLongArray(threadCount)
   val stopFlag = new AtomicBoolean(false)
@@ -56,9 +51,6 @@ abstract private[worker] class Flusher(
   init()
 
   private def init(): Unit = {
-    for (i <- 0 until avgFlushTimeSlidingWindowSize) {
-      avgTimeWindow(i) = (0L, 0L)
-    }
     for (i <- 0 until lastBeginFlushTime.length()) {
       lastBeginFlushTime.set(i, -1)
     }
@@ -75,8 +67,10 @@ abstract private[worker] class Flusher(
                   val flushBeginTime = System.nanoTime()
                   lastBeginFlushTime.set(index, flushBeginTime)
                   task.flush()
-                  flushTotalTime.add(System.nanoTime() - flushBeginTime)
-                  flushCount.increment()
+                  if (flushTimeMetric != null) {
+                    val delta = System.nanoTime() - flushBeginTime
+                    flushTimeMetric.update(delta)
+                  }
                 } catch {
                   case _: ClosedByInterruptException =>
                   case e: IOException =>
@@ -106,32 +100,6 @@ abstract private[worker] class Flusher(
     nextWorkerIndex
   }
 
-  def averageFlushTime(): Long = {
-    if (this.isInstanceOf[LocalFlusher]) {
-      logInfo(s"Flush count in ${this.asInstanceOf[LocalFlusher].mountPoint}" +
-        s" last heartbeat interval: $flushCount")
-    }
-    val currentFlushTime = flushTotalTime.sumThenReset()
-    val currentFlushCount = flushCount.sumThenReset()
-    if (currentFlushCount >= avgFlushTimeSlidingWindowMinCount) {
-      avgTimeWindow(avgTimeWindowCurrentIndex) = (currentFlushTime, currentFlushCount)
-      avgTimeWindowCurrentIndex = (avgTimeWindowCurrentIndex + 1) % avgFlushTimeSlidingWindowSize
-    }
-
-    var totalFlushTime = 0L
-    var totalFlushCount = 0L
-    avgTimeWindow.foreach { case (flushTime, flushCount) =>
-      totalFlushTime = totalFlushTime + flushTime
-      totalFlushCount = totalFlushCount + flushCount
-    }
-
-    if (totalFlushCount != 0) {
-      totalFlushTime / totalFlushCount
-    } else {
-      0L
-    }
-  }
-
   def takeBuffer(): CompositeByteBuf = {
     var buffer = bufferQueue.poll()
     if (buffer == null) {
@@ -177,13 +145,11 @@ private[worker] class LocalFlusher(
     val deviceMonitor: DeviceMonitor,
     threadCount: Int,
     val mountPoint: String,
-    avgFlushTimeSlidingWindowSize: Int,
-    flushAvgTimeMinimumCount: Int,
-    val diskType: StorageInfo.Type) extends Flusher(
+    val diskType: StorageInfo.Type,
+    timeWindow: TimeWindow) extends Flusher(
     workerSource,
     threadCount,
-    avgFlushTimeSlidingWindowSize,
-    flushAvgTimeMinimumCount)
+    timeWindow)
   with DeviceObserver with Logging {
 
   deviceMonitor.registerFlusher(this)
@@ -215,14 +181,10 @@ private[worker] class LocalFlusher(
 
 final private[worker] class HdfsFlusher(
     workerSource: AbstractSource,
-    hdfsFlusherThreads: Int,
-    flushAvgTimeWindowSize: Int,
-    avgFlushTimeSlidingWindowMinCount: Int) extends Flusher(
+    hdfsFlusherThreads: Int) extends Flusher(
     workerSource,
     hdfsFlusherThreads,
-    flushAvgTimeWindowSize,
-    avgFlushTimeSlidingWindowMinCount) with Logging {
-
+    null) with Logging {
   override def toString: String = s"HdfsFlusher@$flusherId"
 
   override def processIOException(e: IOException, deviceErrorType: DiskStatus): Unit = {
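
For reference, the averaging that Flusher no longer carries inline is roughly the
stand-alone sketch below. The class and method names are illustrative assumptions;
the actual org.apache.celeborn.common.meta.TimeWindow that the Flusher now updates
is not shown in this excerpt and may differ in detail.

    import java.util.concurrent.atomic.LongAdder

    // Illustrative sliding-window average mirroring the removed
    // Flusher.averageFlushTime() logic; not the real TimeWindow API.
    class SlidingAverage(windowSize: Int, minSampleCount: Int) {
      private val totalDelta = new LongAdder
      private val sampleCount = new LongAdder
      private val window = Array.fill(windowSize)((0L, 0L))
      private var currentIndex = 0

      // Record one sample, e.g. the nanoseconds a single flush (or fetch) took.
      def update(delta: Long): Unit = {
        totalDelta.add(delta)
        sampleCount.increment()
      }

      // Roll the samples accumulated since the last call into the window and
      // return the average over all retained buckets; meant to be called
      // periodically, e.g. once per heartbeat interval.
      def rollAndGetAverage(): Long = {
        val delta = totalDelta.sumThenReset()
        val count = sampleCount.sumThenReset()
        if (count >= minSampleCount) {
          window(currentIndex) = (delta, count)
          currentIndex = (currentIndex + 1) % windowSize
        }
        val (sumDelta, sumCount) = window.foldLeft((0L, 0L)) {
          case ((d, c), (wd, wc)) => (d + wd, c + wc)
        }
        if (sumCount > 0) sumDelta / sumCount else 0L
      }
    }
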
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManager.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManager.scala
index 0962c2d53..cebff1547 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManager.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManager.scala
@@ -39,7 +39,7 @@ import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.exception.CelebornException
 import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.meta.{DeviceInfo, DiskInfo, DiskStatus, FileInfo}
+import org.apache.celeborn.common.meta.{DeviceInfo, DiskInfo, DiskStatus, FileInfo, TimeWindow}
 import org.apache.celeborn.common.metrics.source.AbstractSource
 import org.apache.celeborn.common.network.server.memory.MemoryManager.MemoryPressureListener
 import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType}
@@ -63,7 +63,7 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs
       throw new IOException("Empty working directory configuration!")
     }
 
-    DeviceInfo.getDeviceAndDiskInfos(workingDirInfos)
+    DeviceInfo.getDeviceAndDiskInfos(workingDirInfos, conf)
   }
   val mountPoints = new util.HashSet[String](diskInfos.keySet())
 
@@ -105,9 +105,8 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs
           deviceMonitor,
           diskInfo.threadCount,
           diskInfo.mountPoint,
-          conf.avgFlushTimeSlidingWindowSize,
-          conf.avgFlushTimeSlidingWindowMinCount,
-          diskInfo.storageType)
+          diskInfo.storageType,
+          diskInfo.flushTimeMetrics)
         flushers.put(diskInfo.mountPoint, flusher)
       }
     }
@@ -133,9 +132,7 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs
       StorageManager.hdfsFs = FileSystem.get(hdfsConfiguration)
       Some(new HdfsFlusher(
         workerSource,
-        conf.hdfsFlusherThreads,
-        conf.avgFlushTimeSlidingWindowSize,
-        conf.avgFlushTimeSlidingWindowMinCount))
+        conf.hdfsFlusherThreads))
     } else {
       None
     }
@@ -365,6 +362,15 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs
     }
   }
 
+  def getFetchTimeMetric(file: File): TimeWindow = {
+    if (diskInfos != null) {
+      val diskInfo = diskInfos.get(DeviceInfo.getMountPoint(file.getAbsolutePath, diskInfos))
+      if (diskInfo != null) {
+        diskInfo.fetchTimeMetrics
+      } else null
+    } else null
+  }
+
   def shuffleKeySet(): util.HashSet[String] = {
     val hashSet = new util.HashSet[String]()
     hashSet.addAll(fileInfos.keySet())
@@ -636,9 +642,9 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs
       val workingDirUsableSpace =
         Math.min(diskInfo.configuredUsableSpace - totalUsage, fileSystemReportedUsableSpace)
       logDebug(s"updateDiskInfos  workingDirUsableSpace:$workingDirUsableSpace filemeta:$fileSystemReportedUsableSpace conf:${diskInfo.configuredUsableSpace} totalUsage:$totalUsage")
-      val flushTimeAverage = localFlushers.get(diskInfo.mountPoint).averageFlushTime()
       diskInfo.setUsableSpace(workingDirUsableSpace)
-      diskInfo.setFlushTime(flushTimeAverage)
+      diskInfo.updateFlushTime()
+      diskInfo.updateFetchTime()
     }
     logInfo(s"Updated diskInfos: ${disksSnapshot()}")
   }
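
The FetchHandler change that consumes getFetchTimeMetric is outside this excerpt. As a
purely hypothetical caller-side sketch, recording a fetch duration could look like the
following; only StorageManager.getFetchTimeMetric and TimeWindow.update come from the
diff, while the wrapper name, placement, and timing points are assumptions.

    package org.apache.celeborn.service.deploy.worker.storage

    import java.io.File

    import org.apache.celeborn.common.meta.TimeWindow

    // Hypothetical helper, not the actual FetchHandler change: time a chunk
    // read and feed the per-disk fetch-time metric when the file maps to a
    // tracked DiskInfo (getFetchTimeMetric returns null otherwise).
    private[worker] object FetchTiming {
      def timed[T](storageManager: StorageManager, file: File)(readChunk: => T): T = {
        val metric: TimeWindow = storageManager.getFetchTimeMetric(file)
        val start = System.nanoTime()
        try {
          readChunk
        } finally {
          if (metric != null) {
            metric.update(System.nanoTime() - start)
          }
        }
      }
    }
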
diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java
index f20ea47fb..643bd723f 100644
--- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java
+++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java
@@ -118,13 +118,18 @@ public class FileWriterSuiteJ {
     dirs.$plus$eq(tempDir);
     localFlusher =
         new LocalFlusher(
-            source, DeviceMonitor$.MODULE$.EmptyMonitor(), 1, "disk1", 20, 1, StorageInfo.Type.HDD);
+            source, DeviceMonitor$.MODULE$.EmptyMonitor(), 1, "disk1", StorageInfo.Type.HDD, null);
     MemoryManager.initialize(0.8, 0.9, 0.5, 0.6, 0.1, 0.1, 10, 10);
   }
 
   public static void setupChunkServer(FileInfo info) throws Exception {
     FetchHandler handler =
         new FetchHandler(transConf) {
+          @Override
+          public StorageManager storageManager() {
+            return new StorageManager(CONF, source);
+          }
+
           @Override
           public FileInfo getRawFileInfo(String shuffleKey, String fileName) {
             return info;
@@ -328,12 +333,12 @@ public class FileWriterSuiteJ {
     dirs.$plus$eq(file);
     localFlusher =
         new LocalFlusher(
-            source, DeviceMonitor$.MODULE$.EmptyMonitor(), 1, "disk2", 20, 1, StorageInfo.Type.HDD);
+            source, DeviceMonitor$.MODULE$.EmptyMonitor(), 1, "disk2", StorageInfo.Type.HDD, null);
   }
 
   @Test
   public void testWriteAndChunkRead() throws Exception {
-    final int threadsNum = 8;
+    final int threadsNum = 16;
     File file = getTemporaryFile();
     FileInfo fileInfo = new FileInfo(file, userIdentifier);
     FileWriter fileWriter =
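
The same constructor change, written as a Scala helper equivalent to the updated Java
test setup above. The helper name and its placement in the storage package are
assumptions; DeviceMonitor.EmptyMonitor() is taken to be the Scala spelling of the
DeviceMonitor$.MODULE$.EmptyMonitor() call used in FileWriterSuiteJ.

    package org.apache.celeborn.service.deploy.worker.storage

    import org.apache.celeborn.common.metrics.source.AbstractSource
    import org.apache.celeborn.common.protocol.StorageInfo

    // Sketch of the new LocalFlusher construction: the sliding-window size and
    // minimum-count arguments are gone, and a TimeWindow is passed instead
    // (null here, as in FileWriterSuiteJ, meaning flush times are not recorded).
    private[worker] object LocalFlusherFixture {
      def newFlusher(source: AbstractSource, mountPoint: String): LocalFlusher =
        new LocalFlusher(
          source,
          DeviceMonitor.EmptyMonitor(),
          1,
          mountPoint,
          StorageInfo.Type.HDD,
          null)
    }
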
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/DeviceMonitorSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/DeviceMonitorSuite.scala
index de05c64fc..8082424c4 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/DeviceMonitorSuite.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/DeviceMonitorSuite.scala
@@ -76,8 +76,10 @@ class DeviceMonitorSuite extends AnyFunSuite {
   withObjectMocked[org.apache.celeborn.common.util.Utils.type] {
     when(Utils.runCommand(dfCmd)) thenReturn dfOut
     when(Utils.runCommand(lsCmd)) thenReturn lsOut
-    val (tdeviceInfos, tdiskInfos) = DeviceInfo.getDeviceAndDiskInfos(dirs.asScala.toArray.map(f =>
-      (f, Long.MaxValue, 1, StorageInfo.Type.HDD)))
+    val (tdeviceInfos, tdiskInfos) = DeviceInfo.getDeviceAndDiskInfos(
+      dirs.asScala.toArray.map(f =>
+        (f, Long.MaxValue, 1, StorageInfo.Type.HDD)),
+      conf)
     deviceInfos = tdeviceInfos
     diskInfos = tdiskInfos
   }
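
As the test above shows, DeviceInfo.getDeviceAndDiskInfos now also takes the
CelebornConf, presumably so each DiskInfo can be built with the flush/fetch TimeWindow
metrics that StorageManager reads. A minimal call-shape sketch, with the directory,
capacity, and thread-count values as placeholders and the object name as an assumption:

    import java.io.File

    import org.apache.celeborn.common.CelebornConf
    import org.apache.celeborn.common.meta.DeviceInfo
    import org.apache.celeborn.common.protocol.StorageInfo

    // Minimal call-shape sketch for the updated API; "/mnt/disk1/celeborn" and
    // the Long.MaxValue / 1 tuple values are placeholders, as in the test above.
    object DiskDiscoveryExample {
      def discover(conf: CelebornConf) = {
        val workingDirInfos =
          Array((new File("/mnt/disk1/celeborn"), Long.MaxValue, 1, StorageInfo.Type.HDD))
        // Returns (deviceInfos, diskInfos); diskInfos is keyed by mount point,
        // matching how StorageManager builds its mountPoints set.
        val (deviceInfos, diskInfos) = DeviceInfo.getDeviceAndDiskInfos(workingDirInfos, conf)
        diskInfos
      }
    }
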