Posted to commits@spark.apache.org by td...@apache.org on 2016/01/08 11:02:18 UTC

spark git commit: [SPARK-12591][STREAMING] Register OpenHashMapBasedStateMap for Kryo (branch 1.6)

Repository: spark
Updated Branches:
  refs/heads/branch-1.6 a7c36362f -> 0d96c5453


[SPARK-12591][STREAMING] Register OpenHashMapBasedStateMap for Kryo (branch 1.6)

Backport #10609 to branch-1.6.

Author: Shixiong Zhu <sh...@databricks.com>

Closes #10656 from zsxwing/SPARK-12591-branch-1.6.
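
For context: the new Kryo support only comes into play when Kryo is the
configured serializer. A minimal sketch of enabling it (standard Spark
configuration, illustrative and not part of this patch):

    import org.apache.spark.SparkConf

    // With Kryo selected, streaming state such as OpenHashMapBasedStateMap is
    // serialized via the KryoSerializable write()/read() methods added below.
    val conf = new SparkConf()
      .setAppName("stateful-streaming")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")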


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0d96c545
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0d96c545
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0d96c545

Branch: refs/heads/branch-1.6
Commit: 0d96c54534d8bfca191c892b98397a176bc46152
Parents: a7c3636
Author: Shixiong Zhu <sh...@databricks.com>
Authored: Fri Jan 8 02:02:06 2016 -0800
Committer: Tathagata Das <ta...@gmail.com>
Committed: Fri Jan 8 02:02:06 2016 -0800

----------------------------------------------------------------------
 .../spark/serializer/KryoSerializer.scala       | 24 ++++-
 .../spark/serializer/KryoSerializerSuite.scala  | 20 +++-
 project/MimaExcludes.scala                      |  4 +
 .../apache/spark/streaming/util/StateMap.scala  | 71 ++++++++++-----
 .../apache/spark/streaming/StateMapSuite.scala  | 96 +++++++++++++++++---
 5 files changed, 174 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0d96c545/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index e8268a5..87021ec 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.serializer
 
-import java.io.{DataInput, DataOutput, EOFException, IOException, InputStream, OutputStream}
+import java.io._
 import java.nio.ByteBuffer
 import javax.annotation.Nullable
 
@@ -378,18 +378,24 @@ private[serializer] object KryoSerializer {
   private val toRegisterSerializer = Map[Class[_], KryoClassSerializer[_]](
     classOf[RoaringBitmap] -> new KryoClassSerializer[RoaringBitmap]() {
       override def write(kryo: Kryo, output: KryoOutput, bitmap: RoaringBitmap): Unit = {
-        bitmap.serialize(new KryoOutputDataOutputBridge(output))
+        bitmap.serialize(new KryoOutputObjectOutputBridge(kryo, output))
       }
       override def read(kryo: Kryo, input: KryoInput, cls: Class[RoaringBitmap]): RoaringBitmap = {
         val ret = new RoaringBitmap
-        ret.deserialize(new KryoInputDataInputBridge(input))
+        ret.deserialize(new KryoInputObjectInputBridge(kryo, input))
         ret
       }
     }
   )
 }
 
-private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends DataInput {
+/**
+ * This is a bridge class to wrap KryoInput as an InputStream and ObjectInput. It forwards all
+ * methods of InputStream and ObjectInput to KryoInput. It's usually helpful when an API expects
+ * an InputStream or ObjectInput but you want to use Kryo.
+ */
+private[spark] class KryoInputObjectInputBridge(
+    kryo: Kryo, input: KryoInput) extends FilterInputStream(input) with ObjectInput {
   override def readLong(): Long = input.readLong()
   override def readChar(): Char = input.readChar()
   override def readFloat(): Float = input.readFloat()
@@ -408,9 +414,16 @@ private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends Dat
   override def readBoolean(): Boolean = input.readBoolean()
   override def readUnsignedByte(): Int = input.readByteUnsigned()
   override def readDouble(): Double = input.readDouble()
+  override def readObject(): AnyRef = kryo.readClassAndObject(input)
 }
 
-private[serializer] class KryoOutputDataOutputBridge(output: KryoOutput) extends DataOutput {
+/**
+ * This is a bridge class to wrap KryoOutput as an OutputStream and ObjectOutput. It forwards all
+ * methods of OutputStream and ObjectOutput to KryoOutput. It's usually helpful when an API expects
+ * an OutputStream or ObjectOutput but you want to use Kryo.
+ */
+private[spark] class KryoOutputObjectOutputBridge(
+    kryo: Kryo, output: KryoOutput) extends FilterOutputStream(output) with ObjectOutput  {
   override def writeFloat(v: Float): Unit = output.writeFloat(v)
   // There is no "readChars" counterpart, except maybe "readLine", which is not supported
   override def writeChars(s: String): Unit = throw new UnsupportedOperationException("writeChars")
@@ -426,6 +439,7 @@ private[serializer] class KryoOutputDataOutputBridge(output: KryoOutput) extends
   override def writeChar(v: Int): Unit = output.writeChar(v.toChar)
   override def writeLong(v: Long): Unit = output.writeLong(v)
   override def writeByte(v: Int): Unit = output.writeByte(v)
+  override def writeObject(obj: AnyRef): Unit = kryo.writeClassAndObject(output, obj)
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/0d96c545/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index f700f1d..ad5e1e2 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -363,19 +363,35 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
     bitmap.add(1)
     bitmap.add(3)
     bitmap.add(5)
-    bitmap.serialize(new KryoOutputDataOutputBridge(output))
+    // Ignore Kryo because it doesn't use writeObject
+    bitmap.serialize(new KryoOutputObjectOutputBridge(null, output))
     output.flush()
     output.close()
 
     val inStream = new FileInputStream(tmpfile)
     val input = new KryoInput(inStream)
     val ret = new RoaringBitmap
-    ret.deserialize(new KryoInputDataInputBridge(input))
+    // Ignore Kryo because it doesn't use readObject
+    ret.deserialize(new KryoInputObjectInputBridge(null, input))
     input.close()
     assert(ret == bitmap)
     Utils.deleteRecursively(dir)
   }
 
+  test("KryoOutputObjectOutputBridge.writeObject and KryoInputObjectInputBridge.readObject") {
+    val kryo = new KryoSerializer(conf).newKryo()
+
+    val bytesOutput = new ByteArrayOutputStream()
+    val objectOutput = new KryoOutputObjectOutputBridge(kryo, new KryoOutput(bytesOutput))
+    objectOutput.writeObject("test")
+    objectOutput.close()
+
+    val bytesInput = new ByteArrayInputStream(bytesOutput.toByteArray)
+    val objectInput = new KryoInputObjectInputBridge(kryo, new KryoInput(bytesInput))
+    assert(objectInput.readObject() === "test")
+    objectInput.close()
+  }
+
   test("getAutoReset") {
     val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance]
     assert(ser.getAutoReset)

http://git-wip-us.apache.org/repos/asf/spark/blob/0d96c545/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index d3a3c0c..08b4a23 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -159,6 +159,10 @@ object MimaExcludes {
         // SPARK-3580 Add getNumPartitions method to JavaRDD
         ProblemFilters.exclude[MissingMethodProblem](
           "org.apache.spark.api.java.JavaRDDLike.getNumPartitions")
+      ) ++ Seq(
+        // SPARK-12591 Register OpenHashMapBasedStateMap for Kryo
+        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoInputDataInputBridge"),
+        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoOutputDataOutputBridge")
       )
     case v if v.startsWith("1.5") =>
       Seq(

http://git-wip-us.apache.org/repos/asf/spark/blob/0d96c545/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
index 3f139ad..4e5baeb 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -17,16 +17,20 @@
 
 package org.apache.spark.streaming.util
 
-import java.io.{ObjectInputStream, ObjectOutputStream}
+import java.io._
 
 import scala.reflect.ClassTag
 
+import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
+import com.esotericsoftware.kryo.io.{Input, Output}
+
 import org.apache.spark.SparkConf
+import org.apache.spark.serializer.{KryoOutputObjectOutputBridge, KryoInputObjectInputBridge}
 import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._
 import org.apache.spark.util.collection.OpenHashMap
 
 /** Internal interface for defining the map that keeps track of sessions. */
-private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable {
+private[streaming] abstract class StateMap[K, S] extends Serializable {
 
   /** Get the state for a key if it exists */
   def get(key: K): Option[S]
@@ -54,7 +58,7 @@ private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Ser
 
 /** Companion object for [[StateMap]], with utility methods */
 private[streaming] object StateMap {
-  def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S]
+  def empty[K, S]: StateMap[K, S] = new EmptyStateMap[K, S]
 
   def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = {
     val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold",
@@ -64,7 +68,7 @@ private[streaming] object StateMap {
 }
 
 /** Implementation of StateMap interface representing an empty map */
-private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] {
+private[streaming] class EmptyStateMap[K, S] extends StateMap[K, S] {
   override def put(key: K, session: S, updateTime: Long): Unit = {
     throw new NotImplementedError("put() should not be called on an EmptyStateMap")
   }
@@ -77,21 +81,26 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa
 }
 
 /** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */
-private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
+private[streaming] class OpenHashMapBasedStateMap[K, S](
     @transient @volatile var parentStateMap: StateMap[K, S],
-    initialCapacity: Int = DEFAULT_INITIAL_CAPACITY,
-    deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD
-  ) extends StateMap[K, S] { self =>
+    private var initialCapacity: Int = DEFAULT_INITIAL_CAPACITY,
+    private var deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD
+  )(implicit private var keyClassTag: ClassTag[K], private var stateClassTag: ClassTag[S])
+  extends StateMap[K, S] with KryoSerializable { self =>
 
-  def this(initialCapacity: Int, deltaChainThreshold: Int) = this(
+  def this(initialCapacity: Int, deltaChainThreshold: Int)
+      (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this(
     new EmptyStateMap[K, S],
     initialCapacity = initialCapacity,
     deltaChainThreshold = deltaChainThreshold)
 
-  def this(deltaChainThreshold: Int) = this(
+  def this(deltaChainThreshold: Int)
+      (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this(
     initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold)
 
-  def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD)
+  def this()(implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = {
+    this(DELTA_CHAIN_LENGTH_THRESHOLD)
+  }
 
   require(initialCapacity >= 1, "Invalid initial capacity")
   require(deltaChainThreshold >= 1, "Invalid delta chain threshold")
@@ -206,11 +215,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
   * Serialize the map data. Besides serialization, this method actually compacts the deltas
    * (if needed) in a single pass over all the data in the map.
    */
-
-  private def writeObject(outputStream: ObjectOutputStream): Unit = {
-    // Write all the non-transient fields, especially class tags, etc.
-    outputStream.defaultWriteObject()
-
+  private def writeObjectInternal(outputStream: ObjectOutput): Unit = {
     // Write the data in the delta of this state map
     outputStream.writeInt(deltaMap.size)
     val deltaMapIterator = deltaMap.iterator
@@ -262,11 +267,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
   }
 
   /** Deserialize the map data. */
-  private def readObject(inputStream: ObjectInputStream): Unit = {
-
-    // Read the non-transient fields, especially class tags, etc.
-    inputStream.defaultReadObject()
-
+  private def readObjectInternal(inputStream: ObjectInput): Unit = {
     // Read the data of the delta
     val deltaMapSize = inputStream.readInt()
     deltaMap = if (deltaMapSize != 0) {
@@ -309,6 +310,34 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
     }
     parentStateMap = newParentSessionStore
   }
+
+  private def writeObject(outputStream: ObjectOutputStream): Unit = {
+    // Write all the non-transient fields, especially class tags, etc.
+    outputStream.defaultWriteObject()
+    writeObjectInternal(outputStream)
+  }
+
+  private def readObject(inputStream: ObjectInputStream): Unit = {
+    // Read the non-transient fields, especially class tags, etc.
+    inputStream.defaultReadObject()
+    readObjectInternal(inputStream)
+  }
+
+  override def write(kryo: Kryo, output: Output): Unit = {
+    output.writeInt(initialCapacity)
+    output.writeInt(deltaChainThreshold)
+    kryo.writeClassAndObject(output, keyClassTag)
+    kryo.writeClassAndObject(output, stateClassTag)
+    writeObjectInternal(new KryoOutputObjectOutputBridge(kryo, output))
+  }
+
+  override def read(kryo: Kryo, input: Input): Unit = {
+    initialCapacity = input.readInt()
+    deltaChainThreshold = input.readInt()
+    keyClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[K]]
+    stateClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[S]]
+    readObjectInternal(new KryoInputObjectInputBridge(kryo, input))
+  }
 }
 
 /**
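
A note on the Kryo path above: because Kryo does not run Java's
defaultWriteObject()/defaultReadObject(), write() and read() persist the
capacity, delta-chain threshold, and class tags explicitly before delegating
the map data to writeObjectInternal()/readObjectInternal() through the bridge
classes. A minimal roundtrip sketch (illustrative only; since
OpenHashMapBasedStateMap is private[streaming], it would have to be compiled
inside the org.apache.spark.streaming package):

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.KryoSerializer
    import org.apache.spark.streaming.util.OpenHashMapBasedStateMap

    val ser = new KryoSerializer(new SparkConf).newInstance()
    val map = new OpenHashMapBasedStateMap[Int, Int]()
    map.put(1, 100, updateTime = 1L)

    // Kryo calls the write()/read() above because the map is KryoSerializable.
    val restored = ser.deserialize[OpenHashMapBasedStateMap[Int, Int]](ser.serialize(map))
    assert(restored.get(1) == Some(100))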

http://git-wip-us.apache.org/repos/asf/spark/blob/0d96c545/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
index c4a01ea..ea32bbf 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
@@ -17,15 +17,23 @@
 
 package org.apache.spark.streaming
 
+import org.apache.spark.streaming.rdd.MapWithStateRDDRecord
+
 import scala.collection.{immutable, mutable, Map}
+import scala.reflect.ClassTag
 import scala.util.Random
 
-import org.apache.spark.SparkFunSuite
+import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
+import com.esotericsoftware.kryo.io.{Output, Input}
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer._
 import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap}
-import org.apache.spark.util.Utils
 
 class StateMapSuite extends SparkFunSuite {
 
+  private val conf = new SparkConf()
+
   test("EmptyStateMap") {
     val map = new EmptyStateMap[Int, Int]
     intercept[scala.NotImplementedError] {
@@ -128,17 +136,17 @@ class StateMapSuite extends SparkFunSuite {
     map1.put(2, 200, 2)
     testSerialization(map1, "error deserializing and serialized map with data + no delta")
 
-    val map2 = map1.copy()
+    val map2 = map1.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]]
     // Do not test compaction
-    assert(map2.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
+    assert(map2.shouldCompact === false)
     testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data")
 
     map2.put(3, 300, 3)
     map2.put(4, 400, 4)
     testSerialization(map2, "error deserializing and serialized map with 1 delta + new data")
 
-    val map3 = map2.copy()
-    assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
+    val map3 = map2.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]]
+    assert(map3.shouldCompact === false)
     testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data")
     map3.put(3, 600, 3)
     map3.remove(2)
@@ -267,18 +275,25 @@ class StateMapSuite extends SparkFunSuite {
     assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map")
   }
 
-  private def testSerialization[MapType <: StateMap[Int, Int]](
-    map: MapType, msg: String): MapType = {
-    val deserMap = Utils.deserialize[MapType](
-      Utils.serialize(map), Thread.currentThread().getContextClassLoader)
+  private def testSerialization[T: ClassTag](
+      map: OpenHashMapBasedStateMap[T, T], msg: String): OpenHashMapBasedStateMap[T, T] = {
+    testSerialization(new JavaSerializer(conf), map, msg)
+    testSerialization(new KryoSerializer(conf), map, msg)
+  }
+
+  private def testSerialization[T : ClassTag](
+      serializer: Serializer,
+      map: OpenHashMapBasedStateMap[T, T],
+      msg: String): OpenHashMapBasedStateMap[T, T] = {
+    val deserMap = serializeAndDeserialize(serializer, map)
     assertMap(deserMap, map, 1, msg)
     deserMap
   }
 
   // Assert whether all the data and operations on a state map matches that of a reference state map
-  private def assertMap(
-      mapToTest: StateMap[Int, Int],
-      refMapToTestWith: StateMap[Int, Int],
+  private def assertMap[T](
+      mapToTest: StateMap[T, T],
+      refMapToTestWith: StateMap[T, T],
       time: Long,
       msg: String): Unit = {
     withClue(msg) {
@@ -321,4 +336,59 @@ class StateMapSuite extends SparkFunSuite {
       }
     }
   }
+
+  test("OpenHashMapBasedStateMap - serializing and deserializing with KryoSerializable states") {
+    val map = new OpenHashMapBasedStateMap[KryoState, KryoState]()
+    map.put(new KryoState("a"), new KryoState("b"), 1)
+    testSerialization(
+      new KryoSerializer(conf), map, "error deserializing and serialized KryoSerializable states")
+  }
+
+  test("EmptyStateMap - serializing and deserializing") {
+    val map = StateMap.empty[KryoState, KryoState]
+    // Since EmptyStateMap doesn't contain any data, KryoState won't break JavaSerializer.
+    assert(serializeAndDeserialize(new JavaSerializer(conf), map).
+      isInstanceOf[EmptyStateMap[KryoState, KryoState]])
+    assert(serializeAndDeserialize(new KryoSerializer(conf), map).
+      isInstanceOf[EmptyStateMap[KryoState, KryoState]])
+  }
+
+  test("MapWithStateRDDRecord - serializing and deserializing with KryoSerializable states") {
+    val map = new OpenHashMapBasedStateMap[KryoState, KryoState]()
+    map.put(new KryoState("a"), new KryoState("b"), 1)
+
+    val record =
+      MapWithStateRDDRecord[KryoState, KryoState, KryoState](map, Seq(new KryoState("c")))
+    val deserRecord = serializeAndDeserialize(new KryoSerializer(conf), record)
+    assert(!(record eq deserRecord))
+    assert(record.stateMap.getAll().toSeq === deserRecord.stateMap.getAll().toSeq)
+    assert(record.mappedData === deserRecord.mappedData)
+  }
+
+  private def serializeAndDeserialize[T: ClassTag](serializer: Serializer, t: T): T = {
+    val serializerInstance = serializer.newInstance()
+    serializerInstance.deserialize[T](
+      serializerInstance.serialize(t), Thread.currentThread().getContextClassLoader)
+  }
+}
+
+/** A class that only supports Kryo serialization. */
+private[streaming] final class KryoState(var state: String) extends KryoSerializable {
+
+  override def write(kryo: Kryo, output: Output): Unit = {
+    kryo.writeClassAndObject(output, state)
+  }
+
+  override def read(kryo: Kryo, input: Input): Unit = {
+    state = kryo.readClassAndObject(input).asInstanceOf[String]
+  }
+
+  override def equals(other: Any): Boolean = other match {
+    case that: KryoState => state == that.state
+    case _ => false
+  }
+
+  override def hashCode(): Int = {
+    if (state == null) 0 else state.hashCode()
+  }
 }

