You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/15 23:02:29 UTC

spark git commit: [SPARK-6602][Core]Replace Akka Serialization with Spark Serializer

Repository: spark
Updated Branches:
  refs/heads/master 536533cad -> b9a922e26


[SPARK-6602][Core]Replace Akka Serialization with Spark Serializer

Replace Akka Serialization with Spark Serializer and add unit tests.

Author: zsxwing <zs...@gmail.com>

Closes #7159 from zsxwing/remove-akka-serialization and squashes the following commits:

fc0fca3 [zsxwing] Merge branch 'master' into remove-akka-serialization
cf81a58 [zsxwing] Fix the code style
73251c6 [zsxwing] Add test scope
9ef4af9 [zsxwing] Add AkkaRpcEndpointRef.hashCode
433115c [zsxwing] Remove final
be3edb0 [zsxwing] Support deserializing RpcEndpointRef
ecec410 [zsxwing] Replace Akka Serialization with Spark Serializer


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b9a922e2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b9a922e2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b9a922e2

Branch: refs/heads/master
Commit: b9a922e260bec1b211437f020be37fab46a85db0
Parents: 536533c
Author: zsxwing <zs...@gmail.com>
Authored: Wed Jul 15 14:02:23 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Wed Jul 15 14:02:23 2015 -0700

----------------------------------------------------------------------
 core/pom.xml                                    |   5 +
 .../master/FileSystemPersistenceEngine.scala    |  35 +++---
 .../org/apache/spark/deploy/master/Master.scala |  18 ++-
 .../spark/deploy/master/PersistenceEngine.scala |   8 +-
 .../deploy/master/RecoveryModeFactory.scala     |   9 +-
 .../master/ZooKeeperPersistenceEngine.scala     |  16 +--
 .../scala/org/apache/spark/rpc/RpcEnv.scala     |   6 +
 .../org/apache/spark/rpc/akka/AkkaRpcEnv.scala  |  14 ++-
 .../master/CustomRecoveryModeFactory.scala      |  31 ++---
 .../spark/deploy/master/MasterSuite.scala       |   2 +-
 .../deploy/master/PersistenceEngineSuite.scala  | 126 +++++++++++++++++++
 pom.xml                                         |   6 +
 12 files changed, 214 insertions(+), 62 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b9a922e2/core/pom.xml
----------------------------------------------------------------------
diff --git a/core/pom.xml b/core/pom.xml
index 558cc3f..73f7a75 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -373,6 +373,11 @@
       <scope>test</scope>
     </dependency>
     <dependency>
+      <groupId>org.apache.curator</groupId>
+      <artifactId>curator-test</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
       <groupId>net.razorvine</groupId>
       <artifactId>pyrolite</artifactId>
       <version>4.4</version>

http://git-wip-us.apache.org/repos/asf/spark/blob/b9a922e2/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
index f459ed5..aa379d4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -21,9 +21,8 @@ import java.io._
 
 import scala.reflect.ClassTag
 
-import akka.serialization.Serialization
-
 import org.apache.spark.Logging
+import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer}
 import org.apache.spark.util.Utils
 
 
@@ -32,11 +31,11 @@ import org.apache.spark.util.Utils
  * Files are deleted when applications and workers are removed.
  *
  * @param dir Directory to store files. Created if non-existent (but not recursively).
- * @param serialization Used to serialize our objects.
+ * @param serializer Used to serialize our objects.
  */
 private[master] class FileSystemPersistenceEngine(
     val dir: String,
-    val serialization: Serialization)
+    val serializer: Serializer)
   extends PersistenceEngine with Logging {
 
   new File(dir).mkdir()
@@ -57,27 +56,31 @@ private[master] class FileSystemPersistenceEngine(
   private def serializeIntoFile(file: File, value: AnyRef) {
     val created = file.createNewFile()
     if (!created) { throw new IllegalStateException("Could not create file: " + file) }
-    val serializer = serialization.findSerializerFor(value)
-    val serialized = serializer.toBinary(value)
-    val out = new FileOutputStream(file)
+    val fileOut = new FileOutputStream(file)
+    var out: SerializationStream = null
     Utils.tryWithSafeFinally {
-      out.write(serialized)
+      out = serializer.newInstance().serializeStream(fileOut)
+      out.writeObject(value)
     } {
-      out.close()
+      fileOut.close()
+      if (out != null) {
+        out.close()
+      }
     }
   }
 
   private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = {
-    val fileData = new Array[Byte](file.length().asInstanceOf[Int])
-    val dis = new DataInputStream(new FileInputStream(file))
+    val fileIn = new FileInputStream(file)
+    var in: DeserializationStream = null
     try {
-      dis.readFully(fileData)
+      in = serializer.newInstance().deserializeStream(fileIn)
+      in.readObject[T]()
     } finally {
-      dis.close()
+      fileIn.close()
+      if (in != null) {
+        in.close()
+      }
     }
-    val clazz = m.runtimeClass.asInstanceOf[Class[T]]
-    val serializer = serialization.serializerFor(clazz)
-    serializer.fromBinary(fileData).asInstanceOf[T]
   }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b9a922e2/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 245b047..4615feb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -27,11 +27,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
 import scala.language.postfixOps
 import scala.util.Random
 
-import akka.serialization.Serialization
-import akka.serialization.SerializationExtension
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.rpc.akka.AkkaRpcEnv
 import org.apache.spark.rpc._
 import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
 import org.apache.spark.deploy.{ApplicationDescription, DriverDescription,
@@ -44,6 +41,7 @@ import org.apache.spark.deploy.master.ui.MasterWebUI
 import org.apache.spark.deploy.rest.StandaloneRestServer
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus}
+import org.apache.spark.serializer.{JavaSerializer, Serializer}
 import org.apache.spark.ui.SparkUI
 import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils}
 
@@ -58,9 +56,6 @@ private[master] class Master(
   private val forwardMessageThread =
     ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread")
 
-  // TODO Remove it once we don't use akka.serialization.Serialization
-  private val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
-
   private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
 
   private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
@@ -161,20 +156,21 @@ private[master] class Master(
     masterMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
     applicationMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
 
+    val serializer = new JavaSerializer(conf)
     val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match {
       case "ZOOKEEPER" =>
         logInfo("Persisting recovery state to ZooKeeper")
         val zkFactory =
-          new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(actorSystem))
+          new ZooKeeperRecoveryModeFactory(conf, serializer)
         (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this))
       case "FILESYSTEM" =>
         val fsFactory =
-          new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem))
+          new FileSystemRecoveryModeFactory(conf, serializer)
         (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this))
       case "CUSTOM" =>
         val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory"))
-        val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization])
-          .newInstance(conf, SerializationExtension(actorSystem))
+        val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serializer])
+          .newInstance(conf, serializer)
           .asInstanceOf[StandaloneRecoveryModeFactory]
         (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this))
       case _ =>
@@ -213,7 +209,7 @@ private[master] class Master(
 
   override def receive: PartialFunction[Any, Unit] = {
     case ElectedLeader => {
-      val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData()
+      val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv)
       state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) {
         RecoveryState.ALIVE
       } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/b9a922e2/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
index a03d460..58a00bc 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.deploy.master
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rpc.RpcEnv
 
 import scala.reflect.ClassTag
 
@@ -80,8 +81,11 @@ abstract class PersistenceEngine {
    * Returns the persisted data sorted by their respective ids (which implies that they're
    * sorted by time of creation).
    */
-  final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
-    (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_"))
+  final def readPersistedData(
+      rpcEnv: RpcEnv): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
+    rpcEnv.deserialize { () =>
+      (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_"))
+    }
   }
 
   def close() {}

http://git-wip-us.apache.org/repos/asf/spark/blob/b9a922e2/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
index 351db8f..c4c3283 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
@@ -17,10 +17,9 @@
 
 package org.apache.spark.deploy.master
 
-import akka.serialization.Serialization
-
 import org.apache.spark.{Logging, SparkConf}
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.serializer.Serializer
 
 /**
  * ::DeveloperApi::
@@ -30,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi
  *
  */
 @DeveloperApi
-abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) {
+abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serializer) {
 
   /**
    * PersistenceEngine defines how the persistent data(Information about worker, driver etc..)
@@ -49,7 +48,7 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serial
  * LeaderAgent in this case is a no-op. Since leader is forever leader as the actual
  * recovery is made by restoring from filesystem.
  */
-private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization)
+private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serializer)
   extends StandaloneRecoveryModeFactory(conf, serializer) with Logging {
 
   val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "")
@@ -64,7 +63,7 @@ private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer:
   }
 }
 
-private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization)
+private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serializer)
   extends StandaloneRecoveryModeFactory(conf, serializer) {
 
   def createPersistenceEngine(): PersistenceEngine = {

http://git-wip-us.apache.org/repos/asf/spark/blob/b9a922e2/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 328d95a..563831c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.deploy.master
 
-import akka.serialization.Serialization
+import java.nio.ByteBuffer
 
 import scala.collection.JavaConversions._
 import scala.reflect.ClassTag
@@ -27,9 +27,10 @@ import org.apache.zookeeper.CreateMode
 
 import org.apache.spark.{Logging, SparkConf}
 import org.apache.spark.deploy.SparkCuratorUtil
+import org.apache.spark.serializer.Serializer
 
 
-private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization)
+private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer: Serializer)
   extends PersistenceEngine
   with Logging {
 
@@ -57,17 +58,16 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializat
   }
 
   private def serializeIntoFile(path: String, value: AnyRef) {
-    val serializer = serialization.findSerializerFor(value)
-    val serialized = serializer.toBinary(value)
-    zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized)
+    val serialized = serializer.newInstance().serialize(value)
+    val bytes = new Array[Byte](serialized.remaining())
+    serialized.get(bytes)
+    zk.create().withMode(CreateMode.PERSISTENT).forPath(path, bytes)
   }
 
   private def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = {
     val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename)
-    val clazz = m.runtimeClass.asInstanceOf[Class[T]]
-    val serializer = serialization.serializerFor(clazz)
     try {
-      Some(serializer.fromBinary(fileData).asInstanceOf[T])
+      Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData)))
     } catch {
       case e: Exception => {
         logWarning("Exception while reading persisted file, deleting", e)

http://git-wip-us.apache.org/repos/asf/spark/blob/b9a922e2/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index c9fcc7a..29debe8 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -139,6 +139,12 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
    * creating it manually because different [[RpcEnv]] may have different formats.
    */
   def uriOf(systemName: String, address: RpcAddress, endpointName: String): String
+
+  /**
+   * [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object
+   * that contains [[RpcEndpointRef]]s, the deserialization code should be wrapped by this method.
+   */
+  def deserialize[T](deserializationAction: () => T): T
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b9a922e2/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index f2d87f6..fc17542 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -28,7 +28,7 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add
 import akka.event.Logging.Error
 import akka.pattern.{ask => akkaAsk}
 import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
-import com.google.common.util.concurrent.MoreExecutors
+import akka.serialization.JavaSerializer
 
 import org.apache.spark.{SparkException, Logging, SparkConf}
 import org.apache.spark.rpc._
@@ -239,6 +239,12 @@ private[spark] class AkkaRpcEnv private[akka] (
   }
 
   override def toString: String = s"${getClass.getSimpleName}($actorSystem)"
+
+  override def deserialize[T](deserializationAction: () => T): T = {
+    JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) {
+      deserializationAction()
+    }
+  }
 }
 
 private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
@@ -315,6 +321,12 @@ private[akka] class AkkaRpcEndpointRef(
 
   override def toString: String = s"${getClass.getSimpleName}($actorRef)"
 
+  final override def equals(that: Any): Boolean = that match {
+    case other: AkkaRpcEndpointRef => actorRef == other.actorRef
+    case _ => false
+  }
+
+  final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode()
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/b9a922e2/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala
index f4e5663..8c96b0e 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala
@@ -19,18 +19,19 @@
 // when they are outside of org.apache.spark.
 package other.supplier
 
+import java.nio.ByteBuffer
+
 import scala.collection.mutable
 import scala.reflect.ClassTag
 
-import akka.serialization.Serialization
-
 import org.apache.spark.SparkConf
 import org.apache.spark.deploy.master._
+import org.apache.spark.serializer.Serializer
 
 class CustomRecoveryModeFactory(
   conf: SparkConf,
-  serialization: Serialization
-) extends StandaloneRecoveryModeFactory(conf, serialization) {
+  serializer: Serializer
+) extends StandaloneRecoveryModeFactory(conf, serializer) {
 
   CustomRecoveryModeFactory.instantiationAttempts += 1
 
@@ -40,7 +41,7 @@ class CustomRecoveryModeFactory(
    *
    */
   override def createPersistenceEngine(): PersistenceEngine =
-    new CustomPersistenceEngine(serialization)
+    new CustomPersistenceEngine(serializer)
 
   /**
    * Create an instance of LeaderAgent that decides who gets elected as master.
@@ -53,7 +54,7 @@ object CustomRecoveryModeFactory {
   @volatile var instantiationAttempts = 0
 }
 
-class CustomPersistenceEngine(serialization: Serialization) extends PersistenceEngine {
+class CustomPersistenceEngine(serializer: Serializer) extends PersistenceEngine {
   val data = mutable.HashMap[String, Array[Byte]]()
 
   CustomPersistenceEngine.lastInstance = Some(this)
@@ -64,10 +65,10 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE
    */
   override def persist(name: String, obj: Object): Unit = {
     CustomPersistenceEngine.persistAttempts += 1
-    serialization.serialize(obj) match {
-      case util.Success(bytes) => data += name -> bytes
-      case util.Failure(cause) => throw new RuntimeException(cause)
-    }
+    val serialized = serializer.newInstance().serialize(obj)
+    val bytes = new Array[Byte](serialized.remaining())
+    serialized.get(bytes)
+    data += name -> bytes
   }
 
   /**
@@ -84,15 +85,9 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE
    */
   override def read[T: ClassTag](prefix: String): Seq[T] = {
     CustomPersistenceEngine.readAttempts += 1
-    val clazz = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]]
     val results = for ((name, bytes) <- data; if name.startsWith(prefix))
-      yield serialization.deserialize(bytes, clazz)
-
-    results.find(_.isFailure).foreach {
-      case util.Failure(cause) => throw new RuntimeException(cause)
-    }
-
-    results.flatMap(_.toOption).toSeq
+      yield serializer.newInstance().deserialize[T](ByteBuffer.wrap(bytes))
+    results.toSeq
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b9a922e2/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
index 9cb6dd4..a8fbaf1 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
@@ -105,7 +105,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually {
       persistenceEngine.addDriver(driverToPersist)
       persistenceEngine.addWorker(workerToPersist)
 
-      val (apps, drivers, workers) = persistenceEngine.readPersistedData()
+      val (apps, drivers, workers) = persistenceEngine.readPersistedData(rpcEnv)
 
       apps.map(_.id) should contain(appToPersist.id)
       drivers.map(_.id) should contain(driverToPersist.id)

http://git-wip-us.apache.org/repos/asf/spark/blob/b9a922e2/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala
new file mode 100644
index 0000000..11e87bd
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.deploy.master
+
+import java.net.ServerSocket
+
+import org.apache.commons.lang3.RandomUtils
+import org.apache.curator.test.TestingServer
+
+import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark.rpc.{RpcEndpoint, RpcEnv}
+import org.apache.spark.serializer.{Serializer, JavaSerializer}
+import org.apache.spark.util.Utils
+
+class PersistenceEngineSuite extends SparkFunSuite {
+
+  test("FileSystemPersistenceEngine") {
+    val dir = Utils.createTempDir()
+    try {
+      val conf = new SparkConf()
+      testPersistenceEngine(conf, serializer =>
+        new FileSystemPersistenceEngine(dir.getAbsolutePath, serializer)
+      )
+    } finally {
+      Utils.deleteRecursively(dir)
+    }
+  }
+
+  test("ZooKeeperPersistenceEngine") {
+    val conf = new SparkConf()
+    // TestingServer logs a port-conflict exception rather than throwing it, so we have to find
+    // a free port ourselves. This cannot guarantee that zkTestServer always starts successfully,
+    // because there is a time gap between finding a free port and starting zkTestServer, but
+    // the probability of failure should be very low.
+    val zkTestServer = new TestingServer(findFreePort(conf))
+    try {
+      testPersistenceEngine(conf, serializer => {
+        conf.set("spark.deploy.zookeeper.url", zkTestServer.getConnectString)
+        new ZooKeeperPersistenceEngine(conf, serializer)
+      })
+    } finally {
+      zkTestServer.stop()
+    }
+  }
+
+  private def testPersistenceEngine(
+      conf: SparkConf, persistenceEngineCreator: Serializer => PersistenceEngine): Unit = {
+    val serializer = new JavaSerializer(conf)
+    val persistenceEngine = persistenceEngineCreator(serializer)
+    persistenceEngine.persist("test_1", "test_1_value")
+    assert(Seq("test_1_value") === persistenceEngine.read[String]("test_"))
+    persistenceEngine.persist("test_2", "test_2_value")
+    assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet)
+    persistenceEngine.unpersist("test_1")
+    assert(Seq("test_2_value") === persistenceEngine.read[String]("test_"))
+    persistenceEngine.unpersist("test_2")
+    assert(persistenceEngine.read[String]("test_").isEmpty)
+
+    // Test deserializing objects that contain RpcEndpointRef
+    val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf))
+    try {
+      // Create a real endpoint so that we can test RpcEndpointRef deserialization
+      val workerEndpoint = rpcEnv.setupEndpoint("worker", new RpcEndpoint {
+        override val rpcEnv: RpcEnv = rpcEnv
+      })
+
+      val workerToPersist = new WorkerInfo(
+        id = "test_worker",
+        host = "127.0.0.1",
+        port = 10000,
+        cores = 0,
+        memory = 0,
+        endpoint = workerEndpoint,
+        webUiPort = 0,
+        publicAddress = ""
+      )
+
+      persistenceEngine.addWorker(workerToPersist)
+
+      val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv)
+
+      assert(storedApps.isEmpty)
+      assert(storedDrivers.isEmpty)
+
+      // Check deserializing WorkerInfo
+      assert(storedWorkers.size == 1)
+      val recoveryWorkerInfo = storedWorkers.head
+      assert(workerToPersist.id === recoveryWorkerInfo.id)
+      assert(workerToPersist.host === recoveryWorkerInfo.host)
+      assert(workerToPersist.port === recoveryWorkerInfo.port)
+      assert(workerToPersist.cores === recoveryWorkerInfo.cores)
+      assert(workerToPersist.memory === recoveryWorkerInfo.memory)
+      assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint)
+      assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort)
+      assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress)
+    } finally {
+      rpcEnv.shutdown()
+      rpcEnv.awaitTermination()
+    }
+  }
+
+  private def findFreePort(conf: SparkConf): Int = {
+    val candidatePort = RandomUtils.nextInt(1024, 65536)
+    Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
+      val socket = new ServerSocket(trialPort)
+      socket.close()
+      (null, trialPort)
+    }, conf)._2
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b9a922e2/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index 370c95d..aa49e2a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -749,6 +749,12 @@
         <version>${curator.version}</version>
       </dependency>
       <dependency>
+        <groupId>org.apache.curator</groupId>
+        <artifactId>curator-test</artifactId>
+        <version>${curator.version}</version>
+        <scope>test</scope>
+      </dependency>
+      <dependency>
         <groupId>org.apache.hadoop</groupId>
         <artifactId>hadoop-client</artifactId>
         <version>${hadoop.version}</version>


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org