You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by pw...@apache.org on 2014/01/10 03:38:33 UTC

[28/37] git commit: Adding unit tests and some refactoring to promote testability.

Adding unit tests and some refactoring to promote testability.


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/e21a707a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/e21a707a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/e21a707a

Branch: refs/heads/master
Commit: e21a707a13b437327cef25d44ef08ddb2e3931af
Parents: b72cceb
Author: Patrick Wendell <pw...@gmail.com>
Authored: Tue Jan 7 00:21:43 2014 -0800
Committer: Patrick Wendell <pw...@gmail.com>
Committed: Tue Jan 7 15:39:47 2014 -0800

----------------------------------------------------------------------
 core/pom.xml                                    |   5 +
 .../spark/deploy/worker/DriverRunner.scala      |  88 +++++++++----
 .../spark/deploy/worker/ExecutorRunner.scala    |  10 +-
 .../spark/deploy/worker/WorkerWatcher.scala     |  14 +-
 .../apache/spark/deploy/JsonProtocolSuite.scala |   2 +-
 .../spark/deploy/worker/DriverRunnerTest.scala  | 131 +++++++++++++++++++
 .../deploy/worker/ExecutorRunnerTest.scala      |   4 +-
 .../deploy/worker/WorkerWatcherSuite.scala      |  32 +++++
 pom.xml                                         |  12 ++
 project/SparkBuild.scala                        |   1 +
 10 files changed, 264 insertions(+), 35 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/e21a707a/core/pom.xml
----------------------------------------------------------------------
diff --git a/core/pom.xml b/core/pom.xml
index aac0a9d..1c52b33 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -99,6 +99,11 @@
             <artifactId>akka-slf4j_${scala.binary.version}</artifactId>
         </dependency>
         <dependency>
+            <groupId>${akka.group}</groupId>
+            <artifactId>akka-testkit_${scala.binary.version}</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
             <groupId>org.scala-lang</groupId>
             <artifactId>scala-library</artifactId>
         </dependency>

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/e21a707a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index f726089..d13d7ef 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -19,6 +19,7 @@ package org.apache.spark.deploy.worker
 
 import java.io._
 
+import scala.collection.JavaConversions._
 import scala.collection.mutable.Map
 
 import akka.actor.ActorRef
@@ -47,6 +48,16 @@ private[spark] class DriverRunner(
   @volatile var process: Option[Process] = None
   @volatile var killed = false
 
+  // Decoupled for testing
+  private[deploy] def setClock(_clock: Clock) = clock = _clock
+  private[deploy] def setSleeper(_sleeper: Sleeper) = sleeper = _sleeper
+  private var clock = new Clock {
+    def currentTimeMillis(): Long = System.currentTimeMillis()
+  }
+  private var sleeper = new Sleeper {
+    def sleep(seconds: Int): Unit = (0 until seconds).takeWhile(f => {Thread.sleep(1000); !killed})
+  }
+
   /** Starts a thread to run and manage the driver. */
   def start() = {
     new Thread("DriverRunner for " + driverId) {
@@ -63,10 +74,9 @@ private[spark] class DriverRunner(
           env("SPARK_CLASSPATH") = env.getOrElse("SPARK_CLASSPATH", "") + s":$localJarFilename"
           val newCommand = Command(driverDesc.command.mainClass,
             driverDesc.command.arguments.map(substituteVariables), env)
-
           val command = CommandUtils.buildCommandSeq(newCommand, driverDesc.mem,
             sparkHome.getAbsolutePath)
-          runCommand(command, env, driverDir, driverDesc.supervise)
+          launchDriver(command, env, driverDir, driverDesc.supervise)
         }
         catch {
           case e: Exception => exn = Some(e)
@@ -116,7 +126,7 @@ private[spark] class DriverRunner(
 
     val jarPath = new Path(driverDesc.jarUrl)
 
-    val emptyConf = new Configuration() // TODO: In docs explain it needs to be full HDFS path
+    val emptyConf = new Configuration()
     val jarFileSystem = jarPath.getFileSystem(emptyConf)
 
     val destPath = new File(driverDir.getAbsolutePath, jarPath.getName)
@@ -136,51 +146,77 @@ private[spark] class DriverRunner(
     localJarFilename
   }
 
-  /** Launch the supplied command. */
-  private def runCommand(command: Seq[String], envVars: Map[String, String], baseDir: File,
-      supervise: Boolean) {
+  private def launchDriver(command: Seq[String], envVars: Map[String, String], baseDir: File,
+                           supervise: Boolean) {
+    val builder = new ProcessBuilder(command: _*).directory(baseDir)
+    envVars.map{ case(k,v) => builder.environment().put(k, v) }
+
+    def initialize(process: Process) = {
+      // Redirect stdout and stderr to files
+      val stdout = new File(baseDir, "stdout")
+      CommandUtils.redirectStream(process.getInputStream, stdout)
+
+      val stderr = new File(baseDir, "stderr")
+      val header = "Launch Command: %s\n%s\n\n".format(
+        command.mkString("\"", "\" \"", "\""), "=" * 40)
+      Files.append(header, stderr, Charsets.UTF_8)
+      CommandUtils.redirectStream(process.getErrorStream, stderr)
+    }
+    runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise)
+  }
 
+  private[deploy] def runCommandWithRetry(command: ProcessBuilderLike, initialize: Process => Unit,
+    supervise: Boolean) {
     // Time to wait between submission retries.
     var waitSeconds = 1
     // A run of this many seconds resets the exponential back-off.
-    val successfulRunDuration = 1
+    val successfulRunDuration = 5
 
     var keepTrying = !killed
 
     while (keepTrying) {
-      logInfo("Launch Command: " + command.mkString("\"", "\" \"", "\""))
-      val builder = new ProcessBuilder(command: _*).directory(baseDir)
-      envVars.map{ case(k,v) => builder.environment().put(k, v) }
+      logInfo("Launch Command: " + command.command.mkString("\"", "\" \"", "\""))
 
       synchronized {
         if (killed) { return }
-
-        process = Some(builder.start())
-
-        // Redirect stdout and stderr to files
-        val stdout = new File(baseDir, "stdout")
-        CommandUtils.redirectStream(process.get.getInputStream, stdout)
-
-        val stderr = new File(baseDir, "stderr")
-        val header = "Launch Command: %s\n%s\n\n".format(
-          command.mkString("\"", "\" \"", "\""), "=" * 40)
-        Files.append(header, stderr, Charsets.UTF_8)
-        CommandUtils.redirectStream(process.get.getErrorStream, stderr)
+        process = Some(command.start())
+        initialize(process.get)
       }
 
-      val processStart = System.currentTimeMillis()
+      val processStart = clock.currentTimeMillis()
       val exitCode = process.get.waitFor()
-      if (System.currentTimeMillis() - processStart > successfulRunDuration * 1000) {
+      if (clock.currentTimeMillis() - processStart > successfulRunDuration * 1000) {
         waitSeconds = 1
       }
 
       if (supervise && exitCode != 0 && !killed) {
-        waitSeconds = waitSeconds * 2 // exponential back-off
         logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.")
-        (0 until waitSeconds).takeWhile(f => {Thread.sleep(1000); !killed})
+        sleeper.sleep(waitSeconds)
+        waitSeconds = waitSeconds * 2 // exponential back-off
       }
 
       keepTrying = supervise && exitCode != 0 && !killed
     }
   }
 }
+
+private[deploy] trait Clock {
+  def currentTimeMillis(): Long
+}
+
+private[deploy] trait Sleeper {
+  def sleep(seconds: Int)
+}
+
+// Needed because ProcessBuilder is a final class and cannot be mocked
+private[deploy] trait ProcessBuilderLike {
+  def start(): Process
+  def command: Seq[String]
+}
+
+private[deploy] object ProcessBuilderLike {
+  def apply(processBuilder: ProcessBuilder) = new ProcessBuilderLike {
+    def start() = processBuilder.start()
+    def command = processBuilder.command()
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/e21a707a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index fdc9a34..a9cb998 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -98,6 +98,12 @@ private[spark] class ExecutorRunner(
     case other => other
   }
 
+  def getCommandSeq = {
+    val command = Command(appDesc.command.mainClass,
+      appDesc.command.arguments.map(substituteVariables), appDesc.command.environment)
+    CommandUtils.buildCommandSeq(command, memory, sparkHome.getAbsolutePath)
+  }
+
   /**
    * Download and run the executor described in our ApplicationDescription
    */
@@ -110,9 +116,7 @@ private[spark] class ExecutorRunner(
       }
 
       // Launch the process
-      val fullCommand = new Command(appDesc.command.mainClass,
-        appDesc.command.arguments.map(substituteVariables), appDesc.command.environment)
-      val command = CommandUtils.buildCommandSeq(fullCommand, memory, sparkHome.getAbsolutePath)
+      val command = getCommandSeq
       logInfo("Launch command: " + command.mkString("\"", "\" \"", "\""))
       val builder = new ProcessBuilder(command: _*).directory(executorDir)
       val env = builder.environment()

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/e21a707a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
index f4184bc..0e0d0cd 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
@@ -10,7 +10,8 @@ import org.apache.spark.deploy.DeployMessages.SendHeartbeat
  * Actor which connects to a worker process and terminates the JVM if the connection is severed.
  * Provides fate sharing between a worker and its associated child processes.
  */
-private[spark] class WorkerWatcher(workerUrl: String) extends Actor with Logging {
+private[spark] class WorkerWatcher(workerUrl: String) extends Actor
+    with Logging {
   override def preStart() {
     context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
 
@@ -19,10 +20,17 @@ private[spark] class WorkerWatcher(workerUrl: String) extends Actor with Logging
     worker ! SendHeartbeat // need to send a message here to initiate connection
   }
 
+  // Used to avoid shutting down JVM during tests
+  private[deploy] var isShutDown = false
+  private[deploy] def setTesting(testing: Boolean) = isTesting = testing
+  private var isTesting = false
+
   // Lets us filter events only from the worker's actor system
   private val expectedHostPort = AddressFromURIString(workerUrl).hostPort
   private def isWorker(address: Address) = address.hostPort == expectedHostPort
 
+  def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1)
+
   override def receive = {
     case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) =>
       logInfo(s"Successfully connected to $workerUrl")
@@ -32,12 +40,12 @@ private[spark] class WorkerWatcher(workerUrl: String) extends Actor with Logging
       // These logs may not be seen if the worker (and associated pipe) has died
       logError(s"Could not initialize connection to worker $workerUrl. Exiting.")
       logError(s"Error was: $cause")
-      System.exit(-1)
+      exitNonZero()
 
     case DisassociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) =>
       // This log message will never be seen
       logError(s"Lost connection to worker actor $workerUrl. Exiting.")
-      System.exit(-1)
+      exitNonZero()
 
     case e: AssociationEvent =>
       // pass through association events relating to other remote actor systems

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/e21a707a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
index 372c9f4..028196f 100644
--- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
@@ -86,7 +86,7 @@ class JsonProtocolSuite extends FunSuite {
   )
 
   def createDriverDesc() = new DriverDescription("hdfs://some-dir/some.jar", 100, 3,
-    createDriverCommand())
+    false, createDriverCommand())
 
   def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", createDriverDesc(), new Date())
 

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/e21a707a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
new file mode 100644
index 0000000..45dbcaf
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
@@ -0,0 +1,131 @@
+package org.apache.spark.deploy.worker
+
+import java.io.File
+
+import scala.collection.JavaConversions._
+
+import org.mockito.Mockito._
+import org.mockito.Matchers._
+import org.scalatest.FunSuite
+
+import org.apache.spark.deploy.{Command, DriverDescription}
+import org.mockito.stubbing.Answer
+import org.mockito.invocation.InvocationOnMock
+
+class DriverRunnerTest extends FunSuite {
+  private def createDriverRunner() = {
+    val command = new Command("mainClass", Seq(), Map())
+    val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command)
+    new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), driverDescription,
+      null, "akka://1.2.3.4/worker/")
+  }
+
+  private def createProcessBuilderAndProcess(): (ProcessBuilderLike, Process) = {
+    val processBuilder = mock(classOf[ProcessBuilderLike])
+    when(processBuilder.command).thenReturn(Seq("mocked", "command"))
+    val process = mock(classOf[Process])
+    when(processBuilder.start()).thenReturn(process)
+    (processBuilder, process)
+  }
+
+  test("Process succeeds instantly") {
+    val runner = createDriverRunner()
+
+    val sleeper = mock(classOf[Sleeper])
+    runner.setSleeper(sleeper)
+
+    val (processBuilder, process) = createProcessBuilderAndProcess()
+    // One failure then a successful run
+    when(process.waitFor()).thenReturn(0)
+    runner.runCommandWithRetry(processBuilder, p => (), supervise = true)
+
+    verify(process, times(1)).waitFor()
+    verify(sleeper, times(0)).sleep(anyInt())
+  }
+
+  test("Process failing several times and then succeeding") {
+    val runner = createDriverRunner()
+
+    val sleeper = mock(classOf[Sleeper])
+    runner.setSleeper(sleeper)
+
+    val (processBuilder, process) = createProcessBuilderAndProcess()
+    // fail, fail, fail, success
+    when(process.waitFor()).thenReturn(-1).thenReturn(-1).thenReturn(-1).thenReturn(0)
+    runner.runCommandWithRetry(processBuilder, p => (), supervise = true)
+
+    verify(process, times(4)).waitFor()
+    verify(sleeper, times(3)).sleep(anyInt())
+    verify(sleeper, times(1)).sleep(1)
+    verify(sleeper, times(1)).sleep(2)
+    verify(sleeper, times(1)).sleep(4)
+  }
+
+  test("Process doesn't restart if not supervised") {
+    val runner = createDriverRunner()
+
+    val sleeper = mock(classOf[Sleeper])
+    runner.setSleeper(sleeper)
+
+    val (processBuilder, process) = createProcessBuilderAndProcess()
+    when(process.waitFor()).thenReturn(-1)
+
+    runner.runCommandWithRetry(processBuilder, p => (), supervise = false)
+
+    verify(process, times(1)).waitFor()
+    verify(sleeper, times(0)).sleep(anyInt())
+  }
+
+  test("Process doesn't restart if killed") {
+    val runner = createDriverRunner()
+
+    val sleeper = mock(classOf[Sleeper])
+    runner.setSleeper(sleeper)
+
+    val (processBuilder, process) = createProcessBuilderAndProcess()
+    when(process.waitFor()).thenAnswer(new Answer[Int] {
+      def answer(invocation: InvocationOnMock): Int = {
+        runner.kill()
+        -1
+      }
+    })
+
+    runner.runCommandWithRetry(processBuilder, p => (), supervise = true)
+
+    verify(process, times(1)).waitFor()
+    verify(sleeper, times(0)).sleep(anyInt())
+  }
+
+  test("Reset of backoff counter") {
+    val runner = createDriverRunner()
+
+    val sleeper = mock(classOf[Sleeper])
+    runner.setSleeper(sleeper)
+
+    val clock = mock(classOf[Clock])
+    runner.setClock(clock)
+
+    val (processBuilder, process) = createProcessBuilderAndProcess()
+
+    when(process.waitFor())
+      .thenReturn(-1) // fail 1
+      .thenReturn(-1) // fail 2
+      .thenReturn(-1) // fail 3
+      .thenReturn(-1) // fail 4
+      .thenReturn(0) // success
+    when(clock.currentTimeMillis())
+      .thenReturn(0).thenReturn(1000) // fail 1 (short)
+      .thenReturn(1000).thenReturn(2000) // fail 2 (short)
+      .thenReturn(2000).thenReturn(10000) // fail 3 (long)
+      .thenReturn(10000).thenReturn(11000) // fail 4 (short)
+      .thenReturn(11000).thenReturn(21000) // success (long)
+
+    runner.runCommandWithRetry(processBuilder, p => (), supervise = true)
+
+    verify(sleeper, times(4)).sleep(anyInt())
+    // Expected sequence of sleeps is 1,2,1,2
+    verify(sleeper, times(2)).sleep(1)
+    verify(sleeper, times(2)).sleep(2)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/e21a707a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
index 7e5aaa3..bdb2c86 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
@@ -31,8 +31,8 @@ class ExecutorRunnerTest extends FunSuite {
       sparkHome, "appUiUrl")
     val appId = "12345-worker321-9876"
     val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome),
-      f("ooga"), ExecutorState.RUNNING)
+      f("ooga"), "blah", ExecutorState.RUNNING)
 
-    assert(er.buildCommandSeq().last === appId)
+    assert(er.getCommandSeq.last === appId)
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/e21a707a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
new file mode 100644
index 0000000..94d88d3
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
@@ -0,0 +1,32 @@
+package org.apache.spark.deploy.worker
+
+
+import akka.testkit.TestActorRef
+import org.scalatest.FunSuite
+import akka.remote.DisassociatedEvent
+import akka.actor.{ActorSystem, AddressFromURIString, Props}
+
+class WorkerWatcherSuite extends FunSuite {
+  test("WorkerWatcher shuts down on valid disassociation") {
+    val actorSystem = ActorSystem("test")
+    val targetWorkerUrl = "akka://1.2.3.4/user/Worker"
+    val targetWorkerAddress = AddressFromURIString(targetWorkerUrl)
+    val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem)
+    val workerWatcher = actorRef.underlyingActor
+    workerWatcher.setTesting(testing = true)
+    actorRef.underlyingActor.receive(new DisassociatedEvent(null, targetWorkerAddress, false))
+    assert(actorRef.underlyingActor.isShutDown)
+  }
+
+  test("WorkerWatcher stays alive on invalid disassociation") {
+    val actorSystem = ActorSystem("test")
+    val targetWorkerUrl = "akka://1.2.3.4/user/Worker"
+    val otherAkkaURL = "akka://4.3.2.1/user/OtherActor"
+    val otherAkkaAddress = AddressFromURIString(otherAkkaURL)
+    val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem)
+    val workerWatcher = actorRef.underlyingActor
+    workerWatcher.setTesting(testing = true)
+    actorRef.underlyingActor.receive(new DisassociatedEvent(null, otherAkkaAddress, false))
+    assert(!actorRef.underlyingActor.isShutDown)
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/e21a707a/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index 78d2f16..7b734c5 100644
--- a/pom.xml
+++ b/pom.xml
@@ -270,6 +270,18 @@
         </exclusions>
       </dependency>
       <dependency>
+        <groupId>${akka.group}</groupId>
+        <artifactId>akka-testkit_${scala.binary.version}</artifactId>
+        <version>${akka.version}</version>
+        <scope>test</scope>
+        <exclusions>
+          <exclusion>
+            <groupId>org.jboss.netty</groupId>
+            <artifactId>netty</artifactId>
+          </exclusion>
+        </exclusions>
+      </dependency>
+      <dependency>
         <groupId>it.unimi.dsi</groupId>
         <artifactId>fastutil</artifactId>
         <version>6.4.4</version>

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/e21a707a/project/SparkBuild.scala
----------------------------------------------------------------------
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 051e510..bd5f3f7 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -233,6 +233,7 @@ object SparkBuild extends Build {
         "org.ow2.asm"              % "asm"              % "4.0",
         "org.spark-project.akka"  %% "akka-remote"      % "2.2.3-shaded-protobuf"  excludeAll(excludeNetty),
         "org.spark-project.akka"  %% "akka-slf4j"       % "2.2.3-shaded-protobuf"  excludeAll(excludeNetty),
+        "org.spark-project.akka"  %% "akka-testkit"     % "2.2.3-shaded-protobuf" % "test",
         "net.liftweb"             %% "lift-json"        % "2.5.1"  excludeAll(excludeNetty),
         "it.unimi.dsi"             % "fastutil"         % "6.4.4",
         "colt"                     % "colt"             % "1.2.0",