You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yh...@apache.org on 2015/05/08 05:59:47 UTC

spark git commit: [SPARK-6986] [SQL] Use Serializer2 in more cases.

Repository: spark
Updated Branches:
  refs/heads/master 92f8f803a -> 3af423c92


[SPARK-6986] [SQL] Use Serializer2 in more cases.

With https://github.com/apache/spark/commit/0a2b15ce43cf6096e1a7ae060b7c8a4010ce3b92, the serialization stream and deserialization stream have enough information to determine whether they are handling a key-value pair, a key, or a value. It is safe to use `SparkSqlSerializer2` in more cases.

Author: Yin Huai <yh...@databricks.com>

Closes #5849 from yhuai/serializer2MoreCases and squashes the following commits:

53a5eaa [Yin Huai] Josh's comments.
487f540 [Yin Huai] Use BufferedOutputStream.
8385f95 [Yin Huai] Always create a new row at the deserialization side to work with sort merge join.
c7e2129 [Yin Huai] Update tests.
4513d13 [Yin Huai] Use Serializer2 in more places.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3af423c9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3af423c9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3af423c9

Branch: refs/heads/master
Commit: 3af423c92f117b5dd4dc6832dc50911cedb29abc
Parents: 92f8f80
Author: Yin Huai <yh...@databricks.com>
Authored: Thu May 7 20:59:42 2015 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Thu May 7 20:59:42 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/execution/Exchange.scala   | 23 ++----
 .../sql/execution/SparkSqlSerializer2.scala     | 74 +++++++++++++-------
 .../execution/SparkSqlSerializer2Suite.scala    | 30 ++++----
 3 files changed, 69 insertions(+), 58 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3af423c9/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 5b2e469..f0d54cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -84,18 +84,8 @@ case class Exchange(
   def serializer(
       keySchema: Array[DataType],
       valueSchema: Array[DataType],
+      hasKeyOrdering: Boolean,
       numPartitions: Int): Serializer = {
-    // In ExternalSorter's spillToMergeableFile function, key-value pairs are written out
-    // through write(key) and then write(value) instead of write((key, value)). Because
-    // SparkSqlSerializer2 assumes that objects passed in are Product2, we cannot safely use
-    // it when spillToMergeableFile in ExternalSorter will be used.
-    // So, we will not use SparkSqlSerializer2 when
-    //  - Sort-based shuffle is enabled and the number of reducers (numPartitions) is greater
-    //     then the bypassMergeThreshold; or
-    //  - newOrdering is defined.
-    val cannotUseSqlSerializer2 =
-      (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || newOrdering.nonEmpty
-
     // It is true when there is no field that needs to be write out.
     // For now, we will not use SparkSqlSerializer2 when noField is true.
     val noField =
@@ -104,14 +94,13 @@ case class Exchange(
 
     val useSqlSerializer2 =
         child.sqlContext.conf.useSqlSerializer2 &&   // SparkSqlSerializer2 is enabled.
-        !cannotUseSqlSerializer2 &&                  // Safe to use Serializer2.
         SparkSqlSerializer2.support(keySchema) &&    // The schema of key is supported.
         SparkSqlSerializer2.support(valueSchema) &&  // The schema of value is supported.
         !noField
 
     val serializer = if (useSqlSerializer2) {
       logInfo("Using SparkSqlSerializer2.")
-      new SparkSqlSerializer2(keySchema, valueSchema)
+      new SparkSqlSerializer2(keySchema, valueSchema, hasKeyOrdering)
     } else {
       logInfo("Using SparkSqlSerializer.")
       new SparkSqlSerializer(sparkConf)
@@ -154,7 +143,8 @@ case class Exchange(
           }
         val keySchema = expressions.map(_.dataType).toArray
         val valueSchema = child.output.map(_.dataType).toArray
-        shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions))
+        shuffled.setSerializer(
+          serializer(keySchema, valueSchema, newOrdering.nonEmpty, numPartitions))
 
         shuffled.map(_._2)
 
@@ -179,7 +169,8 @@ case class Exchange(
             new ShuffledRDD[Row, Null, Null](rdd, part)
           }
         val keySchema = child.output.map(_.dataType).toArray
-        shuffled.setSerializer(serializer(keySchema, null, numPartitions))
+        shuffled.setSerializer(
+          serializer(keySchema, null, newOrdering.nonEmpty, numPartitions))
 
         shuffled.map(_._1)
 
@@ -199,7 +190,7 @@ case class Exchange(
         val partitioner = new HashPartitioner(1)
         val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
         val valueSchema = child.output.map(_.dataType).toArray
-        shuffled.setSerializer(serializer(null, valueSchema, 1))
+        shuffled.setSerializer(serializer(null, valueSchema, false, 1))
         shuffled.map(_._2)
 
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")

http://git-wip-us.apache.org/repos/asf/spark/blob/3af423c9/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
index 35ad987..256d527 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.serializer._
 import org.apache.spark.Logging
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, GenericMutableRow}
 import org.apache.spark.sql.types._
 
 /**
@@ -49,9 +49,9 @@ private[sql] class Serializer2SerializationStream(
     out: OutputStream)
   extends SerializationStream with Logging {
 
-  val rowOut = new DataOutputStream(out)
-  val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
-  val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
+  private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
+  private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
+  private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
 
   override def writeObject[T: ClassTag](t: T): SerializationStream = {
     val kv = t.asInstanceOf[Product2[Row, Row]]
@@ -86,31 +86,44 @@ private[sql] class Serializer2SerializationStream(
 private[sql] class Serializer2DeserializationStream(
     keySchema: Array[DataType],
     valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean,
     in: InputStream)
   extends DeserializationStream with Logging  {
 
-  val rowIn = new DataInputStream(new BufferedInputStream(in))
+  private val rowIn = new DataInputStream(new BufferedInputStream(in))
+
+  private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = {
+    if (schema == null) {
+      () => null
+    } else {
+      if (hasKeyOrdering) {
+        // We have key ordering specified in a ShuffledRDD, it is not safe to reuse a mutable row.
+        () => new GenericMutableRow(schema.length)
+      } else {
+        // It is safe to reuse the mutable row.
+        val mutableRow = new SpecificMutableRow(schema)
+        () => mutableRow
+      }
+    }
+  }
 
-  val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null
-  val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null
-  val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
-  val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
+  // Functions used to return rows for key and value.
+  private val getKey = rowGenerator(keySchema)
+  private val getValue = rowGenerator(valueSchema)
+  // Functions used to read a serialized row from the InputStream and deserialize it.
+  private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
+  private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)
 
   override def readObject[T: ClassTag](): T = {
-    readKeyFunc()
-    readValueFunc()
-
-    (key, value).asInstanceOf[T]
+    (readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T]
   }
 
   override def readKey[T: ClassTag](): T = {
-    readKeyFunc()
-    key.asInstanceOf[T]
+    readKeyFunc(getKey()).asInstanceOf[T]
   }
 
   override def readValue[T: ClassTag](): T = {
-    readValueFunc()
-    value.asInstanceOf[T]
+    readValueFunc(getValue()).asInstanceOf[T]
   }
 
   override def close(): Unit = {
@@ -118,9 +131,10 @@ private[sql] class Serializer2DeserializationStream(
   }
 }
 
-private[sql] class ShuffleSerializerInstance(
+private[sql] class SparkSqlSerializer2Instance(
     keySchema: Array[DataType],
-    valueSchema: Array[DataType])
+    valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean)
   extends SerializerInstance {
 
   def serialize[T: ClassTag](t: T): ByteBuffer =
@@ -137,7 +151,7 @@ private[sql] class ShuffleSerializerInstance(
   }
 
   def deserializeStream(s: InputStream): DeserializationStream = {
-    new Serializer2DeserializationStream(keySchema, valueSchema, s)
+    new Serializer2DeserializationStream(keySchema, valueSchema, hasKeyOrdering, s)
   }
 }
 
@@ -148,12 +162,16 @@ private[sql] class ShuffleSerializerInstance(
  * The schema of keys is represented by `keySchema` and that of values is represented by
  * `valueSchema`.
  */
-private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType])
+private[sql] class SparkSqlSerializer2(
+    keySchema: Array[DataType],
+    valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean)
   extends Serializer
   with Logging
   with Serializable{
 
-  def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema)
+  def newInstance(): SerializerInstance =
+    new SparkSqlSerializer2Instance(keySchema, valueSchema, hasKeyOrdering)
 
   override def supportsRelocationOfSerializedObjects: Boolean = {
     // SparkSqlSerializer2 is stateless and writes no stream headers
@@ -323,11 +341,11 @@ private[sql] object SparkSqlSerializer2 {
    */
   def createDeserializationFunction(
       schema: Array[DataType],
-      in: DataInputStream,
-      mutableRow: SpecificMutableRow): () => Unit = {
-    () => {
-      // If the schema is null, the returned function does nothing when it get called.
-      if (schema != null) {
+      in: DataInputStream): (MutableRow) => Row = {
+    if (schema == null) {
+      (mutableRow: MutableRow) => null
+    } else {
+      (mutableRow: MutableRow) => {
         var i = 0
         while (i < schema.length) {
           schema(i) match {
@@ -440,6 +458,8 @@ private[sql] object SparkSqlSerializer2 {
           }
           i += 1
         }
+
+        mutableRow
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/3af423c9/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
index 27f063d..15337c4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
@@ -148,6 +148,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
       table("shuffle").collect())
   }
 
+  test("key schema is null") {
+    val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
+    val df = sql(s"SELECT $aggregations FROM shuffle")
+    checkSerializer(df.queryExecution.executedPlan, serializerClass)
+    checkAnswer(
+      df,
+      Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
+  }
+
   test("value schema is null") {
     val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
@@ -167,29 +176,20 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
   override def beforeAll(): Unit = {
     super.beforeAll()
     // Sort merge will not be triggered.
-    sql("set spark.sql.shuffle.partitions = 200")
-  }
-
-  test("key schema is null") {
-    val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
-    val df = sql(s"SELECT $aggregations FROM shuffle")
-    checkSerializer(df.queryExecution.executedPlan, serializerClass)
-    checkAnswer(
-      df,
-      Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
+    val bypassMergeThreshold =
+      sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+    sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}")
   }
 }
 
 /** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
 class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {
 
-  // We are expecting SparkSqlSerializer.
-  override val serializerClass: Class[Serializer] =
-    classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]]
-
   override def beforeAll(): Unit = {
     super.beforeAll()
     // To trigger the sort merge.
-    sql("set spark.sql.shuffle.partitions = 201")
+    val bypassMergeThreshold =
+      sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+    sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org