You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/20 08:41:31 UTC

spark git commit: [SPARK-9023] [SQL] Efficiency improvements for UnsafeRows in Exchange

Repository: spark
Updated Branches:
  refs/heads/master 972d8900a -> 79ec07290


[SPARK-9023] [SQL] Efficiency improvements for UnsafeRows in Exchange

This pull request aims to improve the performance of SQL's Exchange operator when shuffling UnsafeRows.  It also makes several general efficiency improvements to Exchange.

Key changes:

- When performing hash partitioning, the old Exchange projected the partitioning columns into a new row and then passed a `(partitioningColumnRow: InternalRow, row: InternalRow)` pair into the shuffle. This is very inefficient because it redundantly serializes the partitioning columns only to discard them immediately after the shuffle. After this patch, Exchange computes each row's partition id before the shuffle and shuffles `(partitionId: Int, row: InternalRow)` pairs instead. This still isn't optimal, since we're still shuffling extra data that we don't need, but it's significantly more efficient than the old implementation. In the future, we may be able to optimize this further once we implement a new shuffle write interface that accepts non-key-value-pair inputs.
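- Exchange's `doExecute()` method has been significantly simplified: it now only chooses a partitioner for each partitioning scheme and shares a single code path for computing partition ids and shuffling, so the new code has less duplication and is easier to understand.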
- When the Exchange's input operator produces UnsafeRows, Exchange will use a specialized `UnsafeRowSerializer` to serialize these rows. This serializer is significantly more efficient because it simply copies each UnsafeRow's underlying bytes. Note that this approach does not work for UnsafeRows that use the ObjectPool mechanism (the serializer asserts that a row has no pool); I did not add support for this because we are planning to remove ObjectPool in the next few weeks.

Author: Josh Rosen <jo...@databricks.com>

Closes #7456 from JoshRosen/unsafe-exchange and squashes the following commits:

7e75259 [Josh Rosen] Fix cast in SparkSqlSerializer2Suite
0082515 [Josh Rosen] Some additional comments + small cleanup to remove an unused parameter
a27cfc1 [Josh Rosen] Add missing newline
741973c [Josh Rosen] Add simple test of UnsafeRow shuffling in Exchange.
359c6a4 [Josh Rosen] Remove println() and add comments
93904e7 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-exchange
8dd3ff2 [Josh Rosen] Exchange outputs UnsafeRows when its child outputs them
dd9c66d [Josh Rosen] Fix for copying logic
035af21 [Josh Rosen] Add logic for choosing when to use UnsafeRowSerializer
7876f31 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-shuffle
cbea80b [Josh Rosen] Add UnsafeRowSerializer
0f2ac86 [Josh Rosen] Import ordering
3ca8515 [Josh Rosen] Big code simplification in Exchange
3526868 [Josh Rosen] Iniitial cut at removing shuffle on KV pairs


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/79ec0729
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/79ec0729
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/79ec0729

Branch: refs/heads/master
Commit: 79ec07290d0b4d16f1643af83824d926304c8f46
Parents: 972d890
Author: Josh Rosen <jo...@databricks.com>
Authored: Sun Jul 19 23:41:28 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sun Jul 19 23:41:28 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/execution/Exchange.scala   | 132 +++++++----------
 .../spark/sql/execution/ShuffledRowRDD.scala    |  80 +++++++++++
 .../sql/execution/SparkSqlSerializer2.scala     |  43 +++---
 .../sql/execution/UnsafeRowSerializer.scala     | 142 +++++++++++++++++++
 .../spark/sql/execution/basicOperators.scala    |   5 +-
 .../spark/sql/execution/ExchangeSuite.scala     |  32 +++++
 .../execution/SparkSqlSerializer2Suite.scala    |   4 +-
 .../execution/UnsafeRowSerializerSuite.scala    |  76 ++++++++++
 8 files changed, 398 insertions(+), 116 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/79ec0729/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index feea4f2..2750053 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.hash.HashShuffleManager
 import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.MutablePair
 import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
 
@@ -44,6 +43,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
 
   override def output: Seq[Attribute] = child.output
 
+  override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+
+  override def canProcessSafeRows: Boolean = true
+
+  override def canProcessUnsafeRows: Boolean = true
+
   /**
    * Determines whether records must be defensively copied before being sent to the shuffle.
    * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The
@@ -112,109 +117,70 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
 
   @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf
 
-  private def getSerializer(
-      keySchema: Array[DataType],
-      valueSchema: Array[DataType],
-      numPartitions: Int): Serializer = {
+  private val serializer: Serializer = {
+    val rowDataTypes = child.output.map(_.dataType).toArray
     // It is true when there is no field that needs to be write out.
     // For now, we will not use SparkSqlSerializer2 when noField is true.
-    val noField =
-      (keySchema == null || keySchema.length == 0) &&
-      (valueSchema == null || valueSchema.length == 0)
+    val noField = rowDataTypes == null || rowDataTypes.length == 0
 
     val useSqlSerializer2 =
         child.sqlContext.conf.useSqlSerializer2 &&   // SparkSqlSerializer2 is enabled.
-        SparkSqlSerializer2.support(keySchema) &&    // The schema of key is supported.
-        SparkSqlSerializer2.support(valueSchema) &&  // The schema of value is supported.
+        SparkSqlSerializer2.support(rowDataTypes) &&  // The schema of row is supported.
         !noField
 
-    val serializer = if (useSqlSerializer2) {
+    if (child.outputsUnsafeRows) {
+      logInfo("Using UnsafeRowSerializer.")
+      new UnsafeRowSerializer(child.output.size)
+    } else if (useSqlSerializer2) {
       logInfo("Using SparkSqlSerializer2.")
-      new SparkSqlSerializer2(keySchema, valueSchema)
+      new SparkSqlSerializer2(rowDataTypes)
     } else {
       logInfo("Using SparkSqlSerializer.")
       new SparkSqlSerializer(sparkConf)
     }
-
-    serializer
   }
 
   protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") {
-    newPartitioning match {
-      case HashPartitioning(expressions, numPartitions) =>
-        val keySchema = expressions.map(_.dataType).toArray
-        val valueSchema = child.output.map(_.dataType).toArray
-        val serializer = getSerializer(keySchema, valueSchema, numPartitions)
-        val part = new HashPartitioner(numPartitions)
-
-        val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
-          child.execute().mapPartitions { iter =>
-            val hashExpressions = newMutableProjection(expressions, child.output)()
-            iter.map(r => (hashExpressions(r).copy(), r.copy()))
-          }
-        } else {
-          child.execute().mapPartitions { iter =>
-            val hashExpressions = newMutableProjection(expressions, child.output)()
-            val mutablePair = new MutablePair[InternalRow, InternalRow]()
-            iter.map(r => mutablePair.update(hashExpressions(r), r))
-          }
-        }
-        val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part)
-        shuffled.setSerializer(serializer)
-        shuffled.map(_._2)
-
+    val rdd = child.execute()
+    val part: Partitioner = newPartitioning match {
+      case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions)
       case RangePartitioning(sortingExpressions, numPartitions) =>
-        val keySchema = child.output.map(_.dataType).toArray
-        val serializer = getSerializer(keySchema, null, numPartitions)
-
-        val childRdd = child.execute()
-        val part: Partitioner = {
-          // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
-          // partition bounds. To get accurate samples, we need to copy the mutable keys.
-          val rddForSampling = childRdd.mapPartitions { iter =>
-            val mutablePair = new MutablePair[InternalRow, Null]()
-            iter.map(row => mutablePair.update(row.copy(), null))
-          }
-          // TODO: RangePartitioner should take an Ordering.
-          implicit val ordering = new RowOrdering(sortingExpressions, child.output)
-          new RangePartitioner(numPartitions, rddForSampling, ascending = true)
-        }
-
-        val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
-          childRdd.mapPartitions { iter => iter.map(row => (row.copy(), null))}
-        } else {
-          childRdd.mapPartitions { iter =>
-            val mutablePair = new MutablePair[InternalRow, Null]()
-            iter.map(row => mutablePair.update(row, null))
-          }
+        // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
+        // partition bounds. To get accurate samples, we need to copy the mutable keys.
+        val rddForSampling = rdd.mapPartitions { iter =>
+          val mutablePair = new MutablePair[InternalRow, Null]()
+          iter.map(row => mutablePair.update(row.copy(), null))
         }
-
-        val shuffled = new ShuffledRDD[InternalRow, Null, Null](rdd, part)
-        shuffled.setSerializer(serializer)
-        shuffled.map(_._1)
-
+        implicit val ordering = new RowOrdering(sortingExpressions, child.output)
+        new RangePartitioner(numPartitions, rddForSampling, ascending = true)
       case SinglePartition =>
-        val valueSchema = child.output.map(_.dataType).toArray
-        val serializer = getSerializer(null, valueSchema, numPartitions = 1)
-        val partitioner = new HashPartitioner(1)
-
-        val rdd = if (needToCopyObjectsBeforeShuffle(partitioner, serializer)) {
-          child.execute().mapPartitions {
-            iter => iter.map(r => (null, r.copy()))
-          }
-        } else {
-          child.execute().mapPartitions { iter =>
-            val mutablePair = new MutablePair[Null, InternalRow]()
-            iter.map(r => mutablePair.update(null, r))
-          }
+        new Partitioner {
+          override def numPartitions: Int = 1
+          override def getPartition(key: Any): Int = 0
         }
-        val shuffled = new ShuffledRDD[Null, InternalRow, InternalRow](rdd, partitioner)
-        shuffled.setSerializer(serializer)
-        shuffled.map(_._2)
-
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
       // TODO: Handle BroadcastPartitioning.
     }
+    def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match {
+      case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)()
+      case RangePartitioning(_, _) | SinglePartition => identity
+      case _ => sys.error(s"Exchange not implemented for $newPartitioning")
+    }
+    val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
+      if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+        rdd.mapPartitions { iter =>
+          val getPartitionKey = getPartitionKeyExtractor()
+          iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
+        }
+      } else {
+        rdd.mapPartitions { iter =>
+          val getPartitionKey = getPartitionKeyExtractor()
+          val mutablePair = new MutablePair[Int, InternalRow]()
+          iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
+        }
+      }
+    }
+    new ShuffledRowRDD(rddWithPartitionIds, serializer, part.numPartitions)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/79ec0729/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
new file mode 100644
index 0000000..88f5b13
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.DataType
+
+private class ShuffledRowRDDPartition(val idx: Int) extends Partition {
+  override val index: Int = idx
+  override def hashCode(): Int = idx
+}
+
+/**
+ * A dummy partitioner for use with records whose partition ids have been pre-computed (i.e. for
+ * use on RDDs of (Int, Row) pairs where the Int is a partition id in the expected range).
+ */
+private class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner {
+  override def getPartition(key: Any): Int = key.asInstanceOf[Int]
+}
+
+/**
+ * This is a specialized version of [[org.apache.spark.rdd.ShuffledRDD]] that is optimized for
+ * shuffling rows instead of Java key-value pairs. Note that something like this should eventually
+ * be implemented in Spark core, but that is blocked by some more general refactorings to shuffle
+ * interfaces / internals.
+ *
+ * @param prev the RDD being shuffled. Elements of this RDD are (partitionId, Row) pairs.
+ *             Partition ids should be in the range [0, numPartitions - 1].
+ * @param serializer the serializer used during the shuffle.
+ * @param numPartitions the number of post-shuffle partitions.
+ */
+class ShuffledRowRDD(
+    @transient var prev: RDD[Product2[Int, InternalRow]],
+    serializer: Serializer,
+    numPartitions: Int)
+  extends RDD[InternalRow](prev.context, Nil) {
+
+  private val part: Partitioner = new PartitionIdPassthrough(numPartitions)
+
+  override def getDependencies: Seq[Dependency[_]] = {
+    List(new ShuffleDependency[Int, InternalRow, InternalRow](prev, part, Some(serializer)))
+  }
+
+  override val partitioner = Some(part)
+
+  override def getPartitions: Array[Partition] = {
+    Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRowRDDPartition(i))
+  }
+
+  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
+    val dep = dependencies.head.asInstanceOf[ShuffleDependency[Int, InternalRow, InternalRow]]
+    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
+      .read()
+      .asInstanceOf[Iterator[Product2[Int, InternalRow]]]
+      .map(_._2)
+  }
+
+  override def clearDependencies() {
+    super.clearDependencies()
+    prev = null
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/79ec0729/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
index 6ed822d..c87e206 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -45,14 +45,12 @@ import org.apache.spark.unsafe.types.UTF8String
  *     the comment of the `serializer` method in [[Exchange]] for more information on it.
  */
 private[sql] class Serializer2SerializationStream(
-    keySchema: Array[DataType],
-    valueSchema: Array[DataType],
+    rowSchema: Array[DataType],
     out: OutputStream)
   extends SerializationStream with Logging {
 
   private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
-  private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
-  private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
+  private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut)
 
   override def writeObject[T: ClassTag](t: T): SerializationStream = {
     val kv = t.asInstanceOf[Product2[Row, Row]]
@@ -63,12 +61,12 @@ private[sql] class Serializer2SerializationStream(
   }
 
   override def writeKey[T: ClassTag](t: T): SerializationStream = {
-    writeKeyFunc(t.asInstanceOf[Row])
+    // No-op.
     this
   }
 
   override def writeValue[T: ClassTag](t: T): SerializationStream = {
-    writeValueFunc(t.asInstanceOf[Row])
+    writeRowFunc(t.asInstanceOf[Row])
     this
   }
 
@@ -85,8 +83,7 @@ private[sql] class Serializer2SerializationStream(
  * The corresponding deserialization stream for [[Serializer2SerializationStream]].
  */
 private[sql] class Serializer2DeserializationStream(
-    keySchema: Array[DataType],
-    valueSchema: Array[DataType],
+    rowSchema: Array[DataType],
     in: InputStream)
   extends DeserializationStream with Logging  {
 
@@ -103,22 +100,20 @@ private[sql] class Serializer2DeserializationStream(
   }
 
   // Functions used to return rows for key and value.
-  private val getKey = rowGenerator(keySchema)
-  private val getValue = rowGenerator(valueSchema)
+  private val getRow = rowGenerator(rowSchema)
   // Functions used to read a serialized row from the InputStream and deserialize it.
-  private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
-  private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)
+  private val readRowFunc = SparkSqlSerializer2.createDeserializationFunction(rowSchema, rowIn)
 
   override def readObject[T: ClassTag](): T = {
-    (readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T]
+    readValue()
   }
 
   override def readKey[T: ClassTag](): T = {
-    readKeyFunc(getKey()).asInstanceOf[T]
+    null.asInstanceOf[T] // intentionally left blank.
   }
 
   override def readValue[T: ClassTag](): T = {
-    readValueFunc(getValue()).asInstanceOf[T]
+    readRowFunc(getRow()).asInstanceOf[T]
   }
 
   override def close(): Unit = {
@@ -127,8 +122,7 @@ private[sql] class Serializer2DeserializationStream(
 }
 
 private[sql] class SparkSqlSerializer2Instance(
-    keySchema: Array[DataType],
-    valueSchema: Array[DataType])
+    rowSchema: Array[DataType])
   extends SerializerInstance {
 
   def serialize[T: ClassTag](t: T): ByteBuffer =
@@ -141,30 +135,25 @@ private[sql] class SparkSqlSerializer2Instance(
     throw new UnsupportedOperationException("Not supported.")
 
   def serializeStream(s: OutputStream): SerializationStream = {
-    new Serializer2SerializationStream(keySchema, valueSchema, s)
+    new Serializer2SerializationStream(rowSchema, s)
   }
 
   def deserializeStream(s: InputStream): DeserializationStream = {
-    new Serializer2DeserializationStream(keySchema, valueSchema, s)
+    new Serializer2DeserializationStream(rowSchema, s)
   }
 }
 
 /**
  * SparkSqlSerializer2 is a special serializer that creates serialization function and
  * deserialization function based on the schema of data. It assumes that values passed in
- * are key/value pairs and values returned from it are also key/value pairs.
- * The schema of keys is represented by `keySchema` and that of values is represented by
- * `valueSchema`.
+ * are Rows.
  */
-private[sql] class SparkSqlSerializer2(
-    keySchema: Array[DataType],
-    valueSchema: Array[DataType])
+private[sql] class SparkSqlSerializer2(rowSchema: Array[DataType])
   extends Serializer
   with Logging
   with Serializable{
 
-  def newInstance(): SerializerInstance =
-    new SparkSqlSerializer2Instance(keySchema, valueSchema)
+  def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(rowSchema)
 
   override def supportsRelocationOfSerializedObjects: Boolean = {
     // SparkSqlSerializer2 is stateless and writes no stream headers

http://git-wip-us.apache.org/repos/asf/spark/blob/79ec0729/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
new file mode 100644
index 0000000..19503ed
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io.{DataInputStream, DataOutputStream, OutputStream, InputStream}
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
+
+import com.google.common.io.ByteStreams
+
+import org.apache.spark.serializer.{SerializationStream, DeserializationStream, SerializerInstance, Serializer}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.unsafe.PlatformDependent
+
+/**
+ * Serializer for serializing [[UnsafeRow]]s during shuffle. Since UnsafeRows are already stored as
+ * bytes, this serializer simply copies those bytes to the underlying output stream. When
+ * deserializing a stream of rows, instances of this serializer mutate and return a single UnsafeRow
+ * instance that is backed by an on-heap byte array.
+ *
+ * Note that this serializer implements only the [[Serializer]] methods that are used during
+ * shuffle, so certain [[SerializerInstance]] methods will throw UnsupportedOperationException.
+ *
+ * This serializer does not support UnsafeRows that use
+ * [[org.apache.spark.sql.catalyst.util.ObjectPool]].
+ *
+ * @param numFields the number of fields in the row being serialized.
+ */
+private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with Serializable {
+  override def newInstance(): SerializerInstance = new UnsafeRowSerializerInstance(numFields)
+  override private[spark] def supportsRelocationOfSerializedObjects: Boolean = true
+}
+
+private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
+
+  private[this] val EOF: Int = -1
+
+  override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
+    private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
+    private[this] val dOut: DataOutputStream = new DataOutputStream(out)
+
+    override def writeValue[T: ClassTag](value: T): SerializationStream = {
+      val row = value.asInstanceOf[UnsafeRow]
+      assert(row.getPool == null, "UnsafeRowSerializer does not support ObjectPool")
+      dOut.writeInt(row.getSizeInBytes)
+      var dataRemaining: Int = row.getSizeInBytes
+      val baseObject = row.getBaseObject
+      var rowReadPosition: Long = row.getBaseOffset
+      while (dataRemaining > 0) {
+        val toTransfer: Int = Math.min(writeBuffer.length, dataRemaining)
+        PlatformDependent.copyMemory(
+          baseObject,
+          rowReadPosition,
+          writeBuffer,
+          PlatformDependent.BYTE_ARRAY_OFFSET,
+          toTransfer)
+        out.write(writeBuffer, 0, toTransfer)
+        rowReadPosition += toTransfer
+        dataRemaining -= toTransfer
+      }
+      this
+    }
+    override def writeKey[T: ClassTag](key: T): SerializationStream = {
+      assert(key.isInstanceOf[Int])
+      this
+    }
+    override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream =
+      throw new UnsupportedOperationException
+    override def writeObject[T: ClassTag](t: T): SerializationStream =
+      throw new UnsupportedOperationException
+    override def flush(): Unit = dOut.flush()
+    override def close(): Unit = {
+      writeBuffer = null
+      dOut.writeInt(EOF)
+      dOut.close()
+    }
+  }
+
+  override def deserializeStream(in: InputStream): DeserializationStream = {
+    new DeserializationStream {
+      private[this] val dIn: DataInputStream = new DataInputStream(in)
+      private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
+      private[this] var row: UnsafeRow = new UnsafeRow()
+      private[this] var rowTuple: (Int, UnsafeRow) = (0, row)
+
+      override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = {
+        new Iterator[(Int, UnsafeRow)] {
+          private[this] var rowSize: Int = dIn.readInt()
+
+          override def hasNext: Boolean = rowSize != EOF
+
+          override def next(): (Int, UnsafeRow) = {
+            if (rowBuffer.length < rowSize) {
+              rowBuffer = new Array[Byte](rowSize)
+            }
+            ByteStreams.readFully(in, rowBuffer, 0, rowSize)
+            row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize, null)
+            rowSize = dIn.readInt() // read the next row's size
+            if (rowSize == EOF) { // We are returning the last row in this stream
+              val _rowTuple = rowTuple
+              // Null these out so that the byte array can be garbage collected once the entire
+              // iterator has been consumed
+              row = null
+              rowBuffer = null
+              rowTuple = null
+              _rowTuple
+            } else {
+              rowTuple
+            }
+          }
+        }
+      }
+      override def asIterator: Iterator[Any] = throw new UnsupportedOperationException
+      override def readKey[T: ClassTag](): T = throw new UnsupportedOperationException
+      override def readValue[T: ClassTag](): T = throw new UnsupportedOperationException
+      override def readObject[T: ClassTag](): T = throw new UnsupportedOperationException
+      override def close(): Unit = dIn.close()
+    }
+  }
+
+  override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException
+  override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
+    throw new UnsupportedOperationException
+  override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
+    throw new UnsupportedOperationException
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/79ec0729/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 82bef26..fdd7ad5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -56,11 +56,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
 case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 
-  @transient lazy val conditionEvaluator: (InternalRow) => Boolean =
-    newPredicate(condition, child.output)
-
   protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
-    iter.filter(conditionEvaluator)
+    iter.filter(newPredicate(condition, child.output))
   }
 
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering

http://git-wip-us.apache.org/repos/asf/spark/blob/79ec0729/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
new file mode 100644
index 0000000..79e903c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
+
+class ExchangeSuite extends SparkPlanTest {
+  test("shuffling UnsafeRows in exchange") {
+    val input = (1 to 1000).map(Tuple1.apply)
+    checkAnswer(
+      input.toDF(),
+      plan => ConvertToSafe(Exchange(SinglePartition, ConvertToUnsafe(plan))),
+      input.map(Row.fromTuple)
+    )
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/79ec0729/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
index 71f6b26..4a53fad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
@@ -132,8 +132,8 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
       expectedSerializerClass: Class[T]): Unit = {
     executedPlan.foreach {
       case exchange: Exchange =>
-        val shuffledRDD = exchange.execute().firstParent.asInstanceOf[ShuffledRDD[_, _, _]]
-        val dependency = shuffledRDD.getDependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+        val shuffledRDD = exchange.execute()
+        val dependency = shuffledRDD.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
         val serializerNotSetMessage =
           s"Expected $expectedSerializerClass as the serializer of Exchange. " +
           s"However, the serializer was not set."

http://git-wip-us.apache.org/repos/asf/spark/blob/79ec0729/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
new file mode 100644
index 0000000..bd788ec
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeRowConverter}
+import org.apache.spark.sql.catalyst.util.ObjectPool
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.PlatformDependent
+
+/**
+ * Unit tests for [[UnsafeRowSerializer]] round-tripping [[UnsafeRow]]s through a byte stream.
+ */
+class UnsafeRowSerializerSuite extends SparkFunSuite {
+
+  /**
+   * Converts an external [[Row]] into an [[UnsafeRow]] backed by a freshly allocated byte array.
+   *
+   * @param row the row to convert
+   * @param schema the data type of each column of `row`, in column order
+   * @param objPool pool handed to both the converter and the resulting row; defaults to null
+   * @return an [[UnsafeRow]] pointing at the newly written bytes
+   */
+  private def toUnsafeRow(
+      row: Row,
+      schema: Array[DataType],
+      objPool: ObjectPool = null): UnsafeRow = {
+    val catalystRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]
+    val converter = new UnsafeRowConverter(schema)
+    // Size the backing buffer exactly to what the converter says this row needs.
+    val numBytes = converter.getSizeRequirement(catalystRow)
+    val buffer = new Array[Byte](numBytes)
+    val baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET
+    converter.writeRow(catalystRow, buffer, baseOffset, numBytes, objPool)
+    val result = new UnsafeRow()
+    result.pointTo(buffer, baseOffset, row.length, numBytes, objPool)
+    result
+  }
+
+  test("toUnsafeRow() test helper method") {
+    val original = Row("Hello", 123)
+    val converted = toUnsafeRow(original, Array(StringType, IntegerType))
+    // get(0) appears to return the internal string representation rather than a
+    // java.lang.String, hence the toString before comparing.
+    assert(original.getString(0) === converted.get(0).toString)
+    assert(original.getInt(1) === converted.getInt(1))
+  }
+
+  test("basic row serialization") {
+    val columnTypes: Array[DataType] = Array(StringType, IntegerType)
+    val expected = Seq(Row("Hello", 1), Row("World", 2)).map(toUnsafeRow(_, columnTypes))
+    val serializerInstance = new UnsafeRowSerializer(numFields = 2).newInstance()
+
+    // Write every row as a (key, value) record; the key is a dummy and is never checked.
+    val out = new ByteArrayOutputStream()
+    val writer = serializerInstance.serializeStream(out)
+    expected.foreach { r =>
+      writer.writeKey(0)
+      writer.writeValue(r)
+    }
+    writer.close()
+
+    // Read the records back and compare each value with the row that was written.
+    val reader = serializerInstance
+      .deserializeStream(new ByteArrayInputStream(out.toByteArray))
+      .asKeyValueIterator
+    expected.foreach { expectedRow =>
+      val actualRow = reader.next()._2.asInstanceOf[UnsafeRow]
+      assert(expectedRow.getSizeInBytes === actualRow.getSizeInBytes)
+      assert(expectedRow.getString(0) === actualRow.getString(0))
+      assert(expectedRow.getInt(1) === actualRow.getInt(1))
+    }
+    // No extra records may follow the ones that were written.
+    assert(!reader.hasNext)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org