You are viewing a plain text version of this content. The canonical (HTML) link for it was present in the original archive page but is not preserved in this plain text rendering.
Posted to commits@celeborn.apache.org by zh...@apache.org on 2023/01/06 02:11:51 UTC

[incubator-celeborn] branch branch-0.2 updated: [CELEBORN-173] refactor minicluster and fix ut (#1147) (#1148)

This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch branch-0.2
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/branch-0.2 by this push:
     new ae07fecc [CELEBORN-173] refactor minicluster and fix ut (#1147) (#1148)
ae07fecc is described below

commit ae07fecc53efd592c2672d15572f75bd5e2138ec
Author: Shuang <lv...@gmail.com>
AuthorDate: Fri Jan 6 10:11:47 2023 +0800

    [CELEBORN-173] refactor minicluster and fix ut (#1147) (#1148)
---
 .../celeborn/tests/client/ShuffleClientSuite.scala |   7 +-
 .../apache/celeborn/tests/spark/HugeDataTest.scala |  13 +--
 .../celeborn/tests/spark/PushdataTimeoutTest.scala |  10 +-
 .../tests/spark/RetryCommitFilesTest.scala         |  10 +-
 .../apache/celeborn/tests/spark/RssHashSuite.scala |  13 +--
 .../apache/celeborn/tests/spark/RssSortSuite.scala |  13 +--
 .../celeborn/tests/spark/SkewJoinSuite.scala       |  16 +--
 .../celeborn/tests/spark/SparkTestBase.scala       | 107 +++------------------
 .../service/deploy/MiniClusterFeature.scala        |  90 +++++++++--------
 .../cluster/ClusterReadWriteTestWithLZ4.scala      |  29 +-----
 .../cluster/ClusterReadWriteTestWithZSTD.scala     |  29 +-----
 .../service/deploy/cluster/ReadWriteTestBase.scala |  21 +++-
 12 files changed, 89 insertions(+), 269 deletions(-)

diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
index ea7e7fe2..46ebb18c 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
@@ -24,7 +24,7 @@ import org.junit.Assert
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.funsuite.AnyFunSuite
 
-import org.apache.celeborn.client.{LifecycleManager, ShuffleClient, ShuffleClientImpl}
+import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl}
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.util.PackedPartitionId
@@ -94,8 +94,7 @@ class ShuffleClientSuite extends AnyFunSuite with MiniClusterFeature
   }
 
   override def afterAll(): Unit = {
-    // TODO refactor MiniCluster later
-    println("test done")
-    sys.exit(0)
+    logInfo("all test complete , stop rss mini cluster")
+    shutdownMiniCluster()
   }
 }
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
index 932998e7..272558bd 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/HugeDataTest.scala
@@ -19,26 +19,15 @@ package org.apache.celeborn.tests.spark
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.SparkSession
-import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+import org.scalatest.BeforeAndAfterEach
 import org.scalatest.funsuite.AnyFunSuite
 
 import org.apache.celeborn.client.ShuffleClient
 
 class HugeDataTest extends AnyFunSuite
   with SparkTestBase
-  with BeforeAndAfterAll
   with BeforeAndAfterEach {
 
-  override def beforeAll(): Unit = {
-    logInfo("test initialized , setup rss mini cluster")
-    tuple = setupRssMiniClusterSpark()
-  }
-
-  override def afterAll(): Unit = {
-    logInfo("all test complete , stop rss mini cluster")
-    clearMiniCluster(tuple)
-  }
-
   override def beforeEach(): Unit = {
     ShuffleClient.reset()
   }
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushdataTimeoutTest.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushdataTimeoutTest.scala
index ed876524..74eb2d60 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushdataTimeoutTest.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushdataTimeoutTest.scala
@@ -19,26 +19,20 @@ package org.apache.celeborn.tests.spark
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.SparkSession
-import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+import org.scalatest.BeforeAndAfterEach
 import org.scalatest.funsuite.AnyFunSuite
 
 import org.apache.celeborn.client.ShuffleClient
 
 class PushdataTimeoutTest extends AnyFunSuite
   with SparkTestBase
-  with BeforeAndAfterAll
   with BeforeAndAfterEach {
 
   override def beforeAll(): Unit = {
     logInfo("test initialized , setup rss mini cluster")
     val workerConf = Map(
       "celeborn.test.pushdataTimeout" -> s"true")
-    tuple = setupRssMiniClusterSpark(masterConfs = null, workerConfs = workerConf)
-  }
-
-  override def afterAll(): Unit = {
-    logInfo("all test complete , stop rss mini cluster")
-    clearMiniCluster(tuple)
+    setUpMiniCluster(masterConfs = null, workerConfs = workerConf)
   }
 
   override def beforeEach(): Unit = {
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryCommitFilesTest.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryCommitFilesTest.scala
index afacdf5d..379e6726 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryCommitFilesTest.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryCommitFilesTest.scala
@@ -19,26 +19,20 @@ package org.apache.celeborn.tests.spark
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.SparkSession
-import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+import org.scalatest.BeforeAndAfterEach
 import org.scalatest.funsuite.AnyFunSuite
 
 import org.apache.celeborn.client.ShuffleClient
 
 class RetryCommitFilesTest extends AnyFunSuite
   with SparkTestBase
-  with BeforeAndAfterAll
   with BeforeAndAfterEach {
 
   override def beforeAll(): Unit = {
     logInfo("test initialized , setup rss mini cluster")
     val workerConf = Map(
       "celeborn.test.retryCommitFiles" -> s"true")
-    tuple = setupRssMiniClusterSpark(masterConfs = null, workerConfs = workerConf)
-  }
-
-  override def afterAll(): Unit = {
-    logInfo("all test complete , stop rss mini cluster")
-    clearMiniCluster(tuple)
+    setUpMiniCluster(masterConfs = null, workerConfs = workerConf)
   }
 
   override def beforeEach(): Unit = {
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssHashSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssHashSuite.scala
index 6e73c217..f7fa360a 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssHashSuite.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssHashSuite.scala
@@ -19,26 +19,15 @@ package org.apache.celeborn.tests.spark
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.SparkSession
-import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+import org.scalatest.BeforeAndAfterEach
 import org.scalatest.funsuite.AnyFunSuite
 
 import org.apache.celeborn.client.ShuffleClient
 
 class RssHashSuite extends AnyFunSuite
   with SparkTestBase
-  with BeforeAndAfterAll
   with BeforeAndAfterEach {
 
-  override def beforeAll(): Unit = {
-    logInfo("test initialized , setup rss mini cluster")
-    tuple = setupRssMiniClusterSpark()
-  }
-
-  override def afterAll(): Unit = {
-    logInfo("all test complete , stop rss mini cluster")
-    clearMiniCluster(tuple)
-  }
-
   override def beforeEach(): Unit = {
     ShuffleClient.reset()
   }
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssSortSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssSortSuite.scala
index ea9537f0..159f3713 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssSortSuite.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RssSortSuite.scala
@@ -19,26 +19,15 @@ package org.apache.celeborn.tests.spark
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.SparkSession
-import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+import org.scalatest.BeforeAndAfterEach
 import org.scalatest.funsuite.AnyFunSuite
 
 import org.apache.celeborn.client.ShuffleClient
 
 class RssSortSuite extends AnyFunSuite
   with SparkTestBase
-  with BeforeAndAfterAll
   with BeforeAndAfterEach {
 
-  override def beforeAll(): Unit = {
-    logInfo("test initialized , setup rss mini cluster")
-    tuple = setupRssMiniClusterSpark()
-  }
-
-  override def afterAll(): Unit = {
-    logInfo("all test complete , stop rss mini cluster")
-    clearMiniCluster(tuple)
-  }
-
   override def beforeEach(): Unit = {
     ShuffleClient.reset()
   }
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala
index 19e74005..69f3768f 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala
@@ -21,28 +21,16 @@ import scala.util.Random
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.SparkSession
-import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+import org.scalatest.BeforeAndAfterEach
 import org.scalatest.funsuite.AnyFunSuite
 
 import org.apache.celeborn.client.ShuffleClient
 import org.apache.celeborn.common.protocol.CompressionCodec
-import org.apache.celeborn.common.util.Utils
 
 class SkewJoinSuite extends AnyFunSuite
   with SparkTestBase
-  with BeforeAndAfterAll
   with BeforeAndAfterEach {
 
-  override def beforeAll(): Unit = {
-    logInfo("test initialized , setup rss mini cluster")
-    tuple = setupRssMiniClusterSpark()
-  }
-
-  override def afterAll(): Unit = {
-    logInfo("all test complete , stop rss mini cluster")
-    clearMiniCluster(tuple)
-  }
-
   override def beforeEach(): Unit = {
     ShuffleClient.reset()
   }
@@ -53,7 +41,7 @@ class SkewJoinSuite extends AnyFunSuite
 
   private def enableRss(conf: SparkConf) = {
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.celeborn.RssShuffleManager")
-      .set("spark.rss.master.address", tuple._1.rpcEnv.address.toString)
+      .set("spark.rss.master.address", masterInfo._1.rpcEnv.address.toString)
       .set("spark.rss.shuffle.split.threshold", "10MB")
   }
 
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
index 92d79d5f..481a2769 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
@@ -21,112 +21,27 @@ import scala.util.Random
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.SparkSession
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
 
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.rpc.RpcEnv
 import org.apache.celeborn.service.deploy.MiniClusterFeature
-import org.apache.celeborn.service.deploy.master.Master
-import org.apache.celeborn.service.deploy.worker.Worker
 
-trait SparkTestBase extends Logging with MiniClusterFeature {
+trait SparkTestBase extends AnyFunSuite
+  with Logging with MiniClusterFeature with BeforeAndAfterAll {
   private val sampleSeq = (1 to 78)
     .map(Random.alphanumeric)
     .toList
     .map(v => (v.toUpper, Random.nextInt(12) + 1))
 
-  @volatile var tuple: (
-      Master,
-      RpcEnv,
-      Worker,
-      RpcEnv,
-      Worker,
-      RpcEnv,
-      Worker,
-      RpcEnv,
-      Thread,
-      Thread,
-      Thread,
-      Thread) = _
-
-  def clearMiniCluster(
-      tuple: (
-          Master,
-          RpcEnv,
-          Worker,
-          RpcEnv,
-          Worker,
-          RpcEnv,
-          Worker,
-          RpcEnv,
-          Thread,
-          Thread,
-          Thread,
-          Thread)): Unit = {
-    tuple._3.close()
-    tuple._4.shutdown()
-    tuple._5.close()
-    tuple._6.shutdown()
-    tuple._7.close()
-    tuple._8.shutdown()
-    tuple._1.close()
-    tuple._2.shutdown()
-    Thread.sleep(5000L)
-    tuple._10.interrupt()
-    tuple._11.interrupt()
-    tuple._12.interrupt()
+  override def beforeAll(): Unit = {
+    logInfo("test initialized , setup rss mini cluster")
+    setUpMiniCluster()
   }
 
-  def setupRssMiniClusterSpark(
-      masterConfs: Map[String, String] = null,
-      workerConfs: Map[String, String] = null): (
-      Master,
-      RpcEnv,
-      Worker,
-      RpcEnv,
-      Worker,
-      RpcEnv,
-      Worker,
-      RpcEnv,
-      Thread,
-      Thread,
-      Thread,
-      Thread) = {
-    Thread.sleep(3000L)
-
-    val (master, masterRpcEnv) = createMaster(masterConfs)
-    val (worker1, workerRpcEnv1) = createWorker(workerConfs)
-    val (worker2, workerRpcEnv2) = createWorker(workerConfs)
-    val (worker3, workerRpcEnv3) = createWorker(workerConfs)
-    val masterThread = runnerWrap(masterRpcEnv.awaitTermination())
-    val workerThread1 = runnerWrap(worker1.initialize())
-    val workerThread2 = runnerWrap(worker2.initialize())
-    val workerThread3 = runnerWrap(worker3.initialize())
-
-    masterThread.start()
-    Thread.sleep(5000L)
-
-    workerThread1.start()
-    workerThread2.start()
-    workerThread3.start()
-    Thread.sleep(5000L)
-
-    assert(worker1.isRegistered())
-    assert(worker2.isRegistered())
-    assert(worker3.isRegistered())
-
-    (
-      master,
-      masterRpcEnv,
-      worker1,
-      workerRpcEnv1,
-      worker2,
-      workerRpcEnv2,
-      worker3,
-      workerRpcEnv3,
-      masterThread,
-      workerThread1,
-      workerThread2,
-      workerThread3)
+  override def afterAll(): Unit = {
+    logInfo("all test complete , stop rss mini cluster")
+    shutdownMiniCluster()
   }
 
   def updateSparkConf(sparkConf: SparkConf, sort: Boolean): SparkConf = {
@@ -137,7 +52,7 @@ trait SparkTestBase extends Logging with MiniClusterFeature {
     sparkConf.set("spark.shuffle.service.enabled", "false")
     sparkConf.set("spark.sql.adaptive.skewJoin.enabled", "false")
     sparkConf.set("spark.sql.adaptive.localShuffleReader.enabled", "false")
-    sparkConf.set("spark.celeborn.master.endpoints", tuple._1.rpcEnv.address.toString)
+    sparkConf.set("spark.celeborn.master.endpoints", masterInfo._1.rpcEnv.address.toString)
     if (sort) {
       sparkConf.set("spark.celeborn.shuffle.writer.mode", "sort")
     }
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
index 277c538d..142f750a 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
@@ -20,9 +20,10 @@ package org.apache.celeborn.service.deploy
 import java.nio.file.Files
 import java.util.concurrent.atomic.AtomicInteger
 
+import scala.collection.mutable
+
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.rpc.RpcEnv
 import org.apache.celeborn.common.util.Utils
 import org.apache.celeborn.service.deploy.master.{Master, MasterArguments}
 import org.apache.celeborn.service.deploy.worker.{Worker, WorkerArguments}
@@ -30,21 +31,23 @@ import org.apache.celeborn.service.deploy.worker.{Worker, WorkerArguments}
 trait MiniClusterFeature extends Logging {
   val workerPrometheusPort = new AtomicInteger(12378)
   val masterPrometheusPort = new AtomicInteger(22378)
+  var masterInfo: (Master, Thread) = _
+  val workerInfos = new mutable.HashMap[Worker, Thread]()
 
-  protected def runnerWrap[T](code: => T): Thread = new Thread(new Runnable {
+  private def runnerWrap[T](code: => T): Thread = new Thread(new Runnable {
     override def run(): Unit = {
       Utils.tryLogNonFatalError(code)
     }
   })
 
-  protected def createTmpDir(): String = {
+  private def createTmpDir(): String = {
     val tmpDir = Files.createTempDirectory("celeborn-")
     logInfo(s"created temp dir: $tmpDir")
     tmpDir.toFile.deleteOnExit()
     tmpDir.toAbsolutePath.toString
   }
 
-  protected def createMaster(map: Map[String, String] = null): (Master, RpcEnv) = {
+  private def createMaster(map: Map[String, String] = null): Master = {
     val conf = new CelebornConf()
     conf.set("celeborn.metrics.enabled", "false")
     val prometheusPort = masterPrometheusPort.getAndIncrement()
@@ -59,10 +62,10 @@ trait MiniClusterFeature extends Logging {
     master.startHttpServer()
 
     Thread.sleep(5000L)
-    (master, master.rpcEnv)
+    master
   }
 
-  protected def createWorker(map: Map[String, String] = null): (Worker, RpcEnv) = {
+  private def createWorker(map: Map[String, String] = null): Worker = {
     logInfo("start create worker for mini cluster")
     val conf = new CelebornConf()
     conf.set("celeborn.worker.storage.dirs", createTmpDir())
@@ -83,63 +86,58 @@ trait MiniClusterFeature extends Logging {
     try {
       val worker = new Worker(conf, workerArguments)
       logInfo("worker created for mini cluster")
-      (worker, worker.rpcEnv)
+      worker
     } catch {
       case e: Exception =>
         logError("create worker failed, detail:", e)
         System.exit(-1)
-        (null, null)
+        null
     }
   }
 
   def setUpMiniCluster(
       masterConfs: Map[String, String] = null,
-      workerConfs: Map[String, String] = null)
-      : (Worker, RpcEnv, Worker, RpcEnv, Worker, RpcEnv, Worker, RpcEnv, Worker, RpcEnv) = {
-    val (master, masterRpcEnv) = createMaster(masterConfs)
-    val masterThread = runnerWrap(masterRpcEnv.awaitTermination())
+      workerConfs: Map[String, String] = null,
+      workerNum: Int = 3): Unit = {
+    val master = createMaster(masterConfs)
+    val masterThread = runnerWrap(master.rpcEnv.awaitTermination())
     masterThread.start()
-
+    masterInfo = (master, masterThread)
     Thread.sleep(5000L)
 
-    val (worker1, workerRpcEnv1) = createWorker(workerConfs)
-    val workerThread1 = runnerWrap(worker1.initialize())
-    workerThread1.start()
+    for (_ <- 1 to workerNum) {
+      val worker = createWorker(workerConfs)
+      val workerThread = runnerWrap(worker.initialize())
+      workerThread.start()
+      workerInfos.put(worker, workerThread)
+    }
 
-    val (worker2, workerRpcEnv2) = createWorker(workerConfs)
-    val workerThread2 = runnerWrap(worker2.initialize())
-    workerThread2.start()
+    Thread.sleep(5000L)
 
-    val (worker3, workerRpcEnv3) = createWorker(workerConfs)
-    val workerThread3 = runnerWrap(worker3.initialize())
-    workerThread3.start()
+    workerInfos.foreach {
+      case (worker, _) => assert(worker.isRegistered())
+    }
+  }
 
-    val (worker4, workerRpcEnv4) = createWorker(workerConfs)
-    val workerThread4 = runnerWrap(worker4.initialize())
-    workerThread4.start()
+  def shutdownMiniCluster(): Unit = {
+    // shutdown workers
+    workerInfos.foreach {
+      case (worker, _) =>
+        worker.close()
+        worker.rpcEnv.shutdown()
+    }
 
-    val (worker5, workerRpcEnv5) = createWorker(workerConfs)
-    val workerThread5 = runnerWrap(worker5.initialize())
-    workerThread5.start()
+    // shutdown masters
+    masterInfo._1.close()
+    masterInfo._1.rpcEnv.shutdown()
 
-    Thread.sleep(5000L)
+    // interrupt threads
+    Thread.sleep(5000)
+    workerInfos.foreach {
+      case (_, thread) =>
+        thread.interrupt()
+    }
 
-    assert(worker1.isRegistered())
-    assert(worker2.isRegistered())
-    assert(worker3.isRegistered())
-    assert(worker4.isRegistered())
-    assert(worker5.isRegistered())
-
-    (
-      worker1,
-      workerRpcEnv1,
-      worker2,
-      workerRpcEnv2,
-      worker3,
-      workerRpcEnv3,
-      worker4,
-      workerRpcEnv4,
-      worker5,
-      workerRpcEnv5)
+    masterInfo._2.interrupt()
   }
 }
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ClusterReadWriteTestWithLZ4.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ClusterReadWriteTestWithLZ4.scala
index 5612faee..cfc14b1d 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ClusterReadWriteTestWithLZ4.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ClusterReadWriteTestWithLZ4.scala
@@ -17,36 +17,9 @@
 
 package org.apache.celeborn.service.deploy.cluster
 
-import java.io.ByteArrayOutputStream
-import java.nio.charset.StandardCharsets
-
-import org.apache.commons.lang3.RandomStringUtils
-import org.scalatest.BeforeAndAfterAll
-import org.scalatest.funsuite.AnyFunSuite
-
-import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl}
-import org.apache.celeborn.common.CelebornConf
-import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.protocol.CompressionCodec
-import org.apache.celeborn.service.deploy.MiniClusterFeature
-
-class ClusterReadWriteTestWithLZ4 extends AnyFunSuite with MiniClusterFeature
-  with BeforeAndAfterAll with ReadWriteTestBase {
 
-  override def beforeAll(): Unit = {
-    val masterPort = 19097
-    val masterConf = Map(
-      "celeborn.master.host" -> "localhost",
-      "celeborn.master.port" -> masterPort.toString)
-    val workerConf = Map(
-      "celeborn.master.endpoints" -> s"localhost:$masterPort")
-    setUpMiniCluster(masterConf, workerConf)
-  }
-
-  override def afterAll(): Unit = {
-    println("test done")
-    sys.exit(0)
-  }
+class ClusterReadWriteTestWithLZ4 extends ReadWriteTestBase {
 
   test(s"test MiniCluster With LZ4") {
     testReadWriteByCode(CompressionCodec.LZ4)
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ClusterReadWriteTestWithZSTD.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ClusterReadWriteTestWithZSTD.scala
index b2be817b..8d207cb2 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ClusterReadWriteTestWithZSTD.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ClusterReadWriteTestWithZSTD.scala
@@ -17,36 +17,9 @@
 
 package org.apache.celeborn.service.deploy.cluster
 
-import java.io.ByteArrayOutputStream
-import java.nio.charset.StandardCharsets
-
-import org.apache.commons.lang3.RandomStringUtils
-import org.scalatest.BeforeAndAfterAll
-import org.scalatest.funsuite.AnyFunSuite
-
-import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl}
-import org.apache.celeborn.common.CelebornConf
-import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.protocol.CompressionCodec
-import org.apache.celeborn.service.deploy.MiniClusterFeature
-
-class ClusterReadWriteTestWithZSTD extends AnyFunSuite with MiniClusterFeature
-  with BeforeAndAfterAll with ReadWriteTestBase {
 
-  override def beforeAll(): Unit = {
-    val masterPort = 19097
-    val masterConf = Map(
-      "celeborn.master.host" -> "localhost",
-      "celeborn.master.port" -> masterPort.toString)
-    val workerConf = Map(
-      "celeborn.master.endpoints" -> s"localhost:$masterPort")
-    setUpMiniCluster(masterConf, workerConf)
-  }
-
-  override def afterAll(): Unit = {
-    println("test done")
-    sys.exit(0)
-  }
+class ClusterReadWriteTestWithZSTD extends ReadWriteTestBase {
 
   test(s"test MiniCluster With ZSTD") {
     testReadWriteByCode(CompressionCodec.ZSTD)
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
index ae50a499..50816a2e 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
@@ -22,16 +22,35 @@ import java.nio.charset.StandardCharsets
 
 import org.apache.commons.lang3.RandomStringUtils
 import org.junit.Assert
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
 
 import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl}
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.protocol.CompressionCodec
+import org.apache.celeborn.service.deploy.MiniClusterFeature
 
-trait ReadWriteTestBase extends Logging {
+trait ReadWriteTestBase extends AnyFunSuite
+  with Logging with MiniClusterFeature with BeforeAndAfterAll {
   val masterPort = 19097
 
+  override def beforeAll(): Unit = {
+    val masterConf = Map(
+      "celeborn.master.host" -> "localhost",
+      "celeborn.master.port" -> masterPort.toString)
+    val workerConf = Map(
+      "celeborn.master.endpoints" -> s"localhost:$masterPort")
+    logInfo("test initialized , setup rss mini cluster")
+    setUpMiniCluster(masterConf, workerConf)
+  }
+
+  override def afterAll(): Unit = {
+    logInfo("all test complete , stop rss mini cluster")
+    shutdownMiniCluster()
+  }
+
   def testReadWriteByCode(codec: CompressionCodec): Unit = {
     val APP = "app-1"