You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself was not preserved in this plain-text version).
Posted to commits@spark.apache.org by da...@apache.org on 2015/07/22 22:02:48 UTC

spark git commit: [SPARK-9024] Unsafe HashJoin/HashOuterJoin/HashSemiJoin

Repository: spark
Updated Branches:
  refs/heads/master 86f80e2b4 -> e0b7ba59a


[SPARK-9024] Unsafe HashJoin/HashOuterJoin/HashSemiJoin

This PR introduces unsafe versions (using UnsafeRow) of HashJoin, HashOuterJoin and HashSemiJoin, including both the broadcast and shuffled variants (except FullOuterJoin, which is better implemented using SortMergeJoin).

It uses a HashMap to store UnsafeRows for now; this will be changed to use BytesToBytesMap for better performance in another PR.

Author: Davies Liu <da...@databricks.com>

Closes #7480 from davies/unsafe_join and squashes the following commits:

6294b1e [Davies Liu] fix projection
10583f1 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join
dede020 [Davies Liu] fix test
84c9807 [Davies Liu] address comments
a05b4f6 [Davies Liu] support UnsafeRow in LeftSemiJoinBNL and BroadcastNestedLoopJoin
611d2ed [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join
9481ae8 [Davies Liu] return UnsafeRow after join()
ca2b40f [Davies Liu] revert unrelated change
68f5cd9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join
0f4380d [Davies Liu] ada a comment
69e38f5 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join
1a40f02 [Davies Liu] refactor
ab1690f [Davies Liu] address comments
60371f2 [Davies Liu] use UnsafeRow in SemiJoin
a6c0b7d [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join
184b852 [Davies Liu] fix style
6acbb11 [Davies Liu] fix tests
95d0762 [Davies Liu] remove println
bea4a50 [Davies Liu] Unsafe HashJoin


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e0b7ba59
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e0b7ba59
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e0b7ba59

Branch: refs/heads/master
Commit: e0b7ba59a1ace9b78a1ad6f3f07fe153db20b52c
Parents: 86f80e2
Author: Davies Liu <da...@databricks.com>
Authored: Wed Jul 22 13:02:43 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Wed Jul 22 13:02:43 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/UnsafeRow.java     | 50 +++++++++++-
 .../sql/execution/UnsafeExternalRowSorter.java  | 10 +--
 .../catalyst/expressions/BoundAttribute.scala   | 19 ++++-
 .../sql/catalyst/expressions/Projection.scala   | 34 +++++++-
 .../sql/execution/joins/BroadcastHashJoin.scala |  2 +-
 .../joins/BroadcastHashOuterJoin.scala          | 32 ++------
 .../joins/BroadcastLeftSemiJoinHash.scala       |  5 +-
 .../joins/BroadcastNestedLoopJoin.scala         | 37 ++++++---
 .../spark/sql/execution/joins/HashJoin.scala    | 43 ++++++++--
 .../sql/execution/joins/HashOuterJoin.scala     | 82 ++++++++++++++++---
 .../sql/execution/joins/HashSemiJoin.scala      | 74 ++++++++++-------
 .../sql/execution/joins/HashedRelation.scala    | 85 +++++++++++++++++++-
 .../sql/execution/joins/LeftSemiJoinBNL.scala   |  3 +
 .../sql/execution/joins/LeftSemiJoinHash.scala  |  4 +-
 .../sql/execution/joins/ShuffledHashJoin.scala  |  2 +-
 .../execution/joins/ShuffledHashOuterJoin.scala | 13 +--
 .../sql/execution/rowFormatConverters.scala     | 21 +++--
 .../org/apache/spark/sql/UnsafeRowSuite.scala   |  4 +-
 .../execution/joins/HashedRelationSuite.scala   | 49 ++++++++---
 .../spark/unsafe/hash/Murmur3_x86_32.java       | 10 ++-
 20 files changed, 444 insertions(+), 135 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 6ce03a4..7f08bf7 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.expressions;
 import java.io.IOException;
 import java.io.OutputStream;
 
-import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.util.ObjectPool;
 import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.bitset.BitSetMethods;
+import org.apache.spark.unsafe.hash.Murmur3_x86_32;
 import org.apache.spark.unsafe.types.UTF8String;
 
 
@@ -354,7 +355,7 @@ public final class UnsafeRow extends MutableRow {
    * This method is only supported on UnsafeRows that do not use ObjectPools.
    */
   @Override
-  public InternalRow copy() {
+  public UnsafeRow copy() {
     if (pool != null) {
       throw new UnsupportedOperationException(
         "Copy is not supported for UnsafeRows that use object pools");
@@ -405,7 +406,50 @@ public final class UnsafeRow extends MutableRow {
   }
 
   @Override
+  public int hashCode() {
+    return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof UnsafeRow) {
+      UnsafeRow o = (UnsafeRow) other;
+      return (sizeInBytes == o.sizeInBytes) &&
+        ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
+          sizeInBytes);
+    }
+    return false;
+  }
+
+  /**
+   * Returns the underlying bytes for this UnsafeRow.
+   */
+  public byte[] getBytes() {
+    if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET
+        && (((byte[]) baseObject).length == sizeInBytes)) {
+      return (byte[]) baseObject;
+    } else {
+      byte[] bytes = new byte[sizeInBytes];
+      PlatformDependent.copyMemory(baseObject, baseOffset, bytes,
+        PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes);
+      return bytes;
+    }
+  }
+
+  // This is for debugging
+  @Override
+  public String toString() {
+    StringBuilder build = new StringBuilder("[");
+    for (int i = 0; i < sizeInBytes; i += 8) {
+      build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i));
+      build.append(',');
+    }
+    build.append(']');
+    return build.toString();
+  }
+
+  @Override
   public boolean anyNull() {
-    return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
+    return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8);
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index d1d81c8..39fd6e1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -28,11 +28,10 @@ import org.apache.spark.SparkEnv;
 import org.apache.spark.TaskContext;
 import org.apache.spark.sql.AbstractScalaRowIterator;
 import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter;
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.catalyst.util.ObjectPool;
-import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
 import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
@@ -176,12 +175,7 @@ final class UnsafeExternalRowSorter {
    */
   public static boolean supportsSchema(StructType schema) {
     // TODO: add spilling note to explain why we do this for now:
-    for (StructField field : schema.fields()) {
-      if (!UnsafeColumnWriter.canEmbed(field.dataType())) {
-        return false;
-      }
-    }
-    return true;
+    return UnsafeProjection.canSupport(schema);
   }
 
   private static final class RowComparator extends RecordComparator {

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index b10a3c8..4a13b68 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -21,7 +21,6 @@ import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
-import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.types._
 
 /**
@@ -34,7 +33,23 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
 
   override def toString: String = s"input[$ordinal, $dataType]"
 
-  override def eval(input: InternalRow): Any = input(ordinal)
+  // Use special getter for primitive types (for UnsafeRow)
+  override def eval(input: InternalRow): Any = {
+    if (input.isNullAt(ordinal)) {
+      null
+    } else {
+      dataType match {
+        case BooleanType => input.getBoolean(ordinal)
+        case ByteType => input.getByte(ordinal)
+        case ShortType => input.getShort(ordinal)
+        case IntegerType | DateType => input.getInt(ordinal)
+        case LongType | TimestampType => input.getLong(ordinal)
+        case FloatType => input.getFloat(ordinal)
+        case DoubleType => input.getDouble(ordinal)
+        case _ => input.get(ordinal)
+      }
+    }
+  }
 
   override def name: String = s"i[$ordinal]"
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 24b01ea..69758e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -83,12 +83,42 @@ abstract class UnsafeProjection extends Projection {
 }
 
 object UnsafeProjection {
+
+  /*
+   * Returns whether UnsafeProjection can support given StructType, Array[DataType] or
+   * Seq[Expression].
+   */
+  def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType))
+  def canSupport(types: Array[DataType]): Boolean = types.forall(UnsafeColumnWriter.canEmbed(_))
+  def canSupport(exprs: Seq[Expression]): Boolean = canSupport(exprs.map(_.dataType).toArray)
+
+  /**
+   * Returns an UnsafeProjection for given StructType.
+   */
   def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType))
 
-  def create(fields: Seq[DataType]): UnsafeProjection = {
+  /**
+   * Returns an UnsafeProjection for given Array of DataTypes.
+   */
+  def create(fields: Array[DataType]): UnsafeProjection = {
     val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))
+    create(exprs)
+  }
+
+  /**
+   * Returns an UnsafeProjection for given sequence of Expressions (bounded).
+   */
+  def create(exprs: Seq[Expression]): UnsafeProjection = {
     GenerateUnsafeProjection.generate(exprs)
   }
+
+  /**
+   * Returns an UnsafeProjection for given sequence of Expressions, which will be bound to
+   * `inputSchema`.
+   */
+  def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
+    create(exprs.map(BindReferences.bindReference(_, inputSchema)))
+  }
 }
 
 /**
@@ -96,6 +126,8 @@ object UnsafeProjection {
  */
 case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection {
 
+  def this(schema: StructType) = this(schema.fields.map(_.dataType))
+
   private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) =>
     new BoundReference(idx, dt, true)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 7ffdce6..abaa4a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -62,7 +62,7 @@ case class BroadcastHashJoin(
   private val broadcastFuture = future {
     // Note that we use .execute().collect() because we don't want to convert data to Scala types
     val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
-    val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
+    val hashed = buildHashRelation(input.iterator)
     sparkContext.broadcast(hashed)
   }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index ab757fc..c9d1a88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql.execution.joins
 
+import scala.concurrent._
+import scala.concurrent.duration._
+
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -26,10 +29,6 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.util.ThreadUtils
 
-import scala.collection.JavaConversions._
-import scala.concurrent._
-import scala.concurrent.duration._
-
 /**
  * :: DeveloperApi ::
  * Performs a outer hash join for two child relations.  When the output RDD of this operator is
@@ -58,28 +57,11 @@ case class BroadcastHashOuterJoin(
   override def requiredChildDistribution: Seq[Distribution] =
     UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
 
-  private[this] lazy val (buildPlan, streamedPlan) = joinType match {
-    case RightOuter => (left, right)
-    case LeftOuter => (right, left)
-    case x =>
-      throw new IllegalArgumentException(
-        s"BroadcastHashOuterJoin should not take $x as the JoinType")
-  }
-
-  private[this] lazy val (buildKeys, streamedKeys) = joinType match {
-    case RightOuter => (leftKeys, rightKeys)
-    case LeftOuter => (rightKeys, leftKeys)
-    case x =>
-      throw new IllegalArgumentException(
-        s"BroadcastHashOuterJoin should not take $x as the JoinType")
-  }
-
   @transient
   private val broadcastFuture = future {
     // Note that we use .execute().collect() because we don't want to convert data to Scala types
     val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
-    // buildHashTable uses code-generated rows as keys, which are not serializable
-    val hashed = buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output))
+    val hashed = buildHashRelation(input.iterator)
     sparkContext.broadcast(hashed)
   }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)
 
@@ -89,21 +71,21 @@ case class BroadcastHashOuterJoin(
     streamedPlan.execute().mapPartitions { streamedIter =>
       val joinedRow = new JoinedRow()
       val hashTable = broadcastRelation.value
-      val keyGenerator = newProjection(streamedKeys, streamedPlan.output)
+      val keyGenerator = streamedKeyGenerator
 
       joinType match {
         case LeftOuter =>
           streamedIter.flatMap(currentRow => {
             val rowKey = keyGenerator(currentRow)
             joinedRow.withLeft(currentRow)
-            leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST))
+            leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey))
           })
 
         case RightOuter =>
           streamedIter.flatMap(currentRow => {
             val rowKey = keyGenerator(currentRow)
             joinedRow.withRight(currentRow)
-            rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
+            rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow)
           })
 
         case x =>

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index 2750f58..f71c0ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -40,15 +40,14 @@ case class BroadcastLeftSemiJoinHash(
     val buildIter = right.execute().map(_.copy()).collect().toIterator
 
     if (condition.isEmpty) {
-      // rowKey may be not serializable (from codegen)
-      val hashSet = buildKeyHashSet(buildIter, copy = true)
+      val hashSet = buildKeyHashSet(buildIter)
       val broadcastedRelation = sparkContext.broadcast(hashSet)
 
       left.execute().mapPartitions { streamIter =>
         hashSemiJoin(streamIter, broadcastedRelation.value)
       }
     } else {
-      val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+      val hashRelation = buildHashRelation(buildIter)
       val broadcastedRelation = sparkContext.broadcast(hashRelation)
 
       left.execute().mapPartitions { streamIter =>

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 60b4266..7006369 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -44,6 +44,19 @@ case class BroadcastNestedLoopJoin(
     case BuildLeft => (right, left)
   }
 
+  override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows
+  override def canProcessUnsafeRows: Boolean = true
+
+  @transient private[this] lazy val resultProjection: Projection = {
+    if (outputsUnsafeRows) {
+      UnsafeProjection.create(schema)
+    } else {
+      new Projection {
+        override def apply(r: InternalRow): InternalRow = r
+      }
+    }
+  }
+
   override def outputPartitioning: Partitioning = streamed.outputPartitioning
 
   override def output: Seq[Attribute] = {
@@ -74,6 +87,7 @@ case class BroadcastNestedLoopJoin(
       val includedBroadcastTuples =
         new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
       val joinedRow = new JoinedRow
+
       val leftNulls = new GenericMutableRow(left.output.size)
       val rightNulls = new GenericMutableRow(right.output.size)
 
@@ -86,11 +100,11 @@ case class BroadcastNestedLoopJoin(
           val broadcastedRow = broadcastedRelation.value(i)
           buildSide match {
             case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
-              matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
+              matchedRows += resultProjection(joinedRow(streamedRow, broadcastedRow)).copy()
               streamRowMatched = true
               includedBroadcastTuples += i
             case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
-              matchedRows += joinedRow(broadcastedRow, streamedRow).copy()
+              matchedRows += resultProjection(joinedRow(broadcastedRow, streamedRow)).copy()
               streamRowMatched = true
               includedBroadcastTuples += i
             case _ =>
@@ -100,9 +114,9 @@ case class BroadcastNestedLoopJoin(
 
         (streamRowMatched, joinType, buildSide) match {
           case (false, LeftOuter | FullOuter, BuildRight) =>
-            matchedRows += joinedRow(streamedRow, rightNulls).copy()
+            matchedRows += resultProjection(joinedRow(streamedRow, rightNulls)).copy()
           case (false, RightOuter | FullOuter, BuildLeft) =>
-            matchedRows += joinedRow(leftNulls, streamedRow).copy()
+            matchedRows += resultProjection(joinedRow(leftNulls, streamedRow)).copy()
           case _ =>
         }
       }
@@ -110,12 +124,9 @@ case class BroadcastNestedLoopJoin(
     }
 
     val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2)
-    val allIncludedBroadcastTuples =
-      if (includedBroadcastTuples.count == 0) {
-        new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
-      } else {
-        includedBroadcastTuples.reduce(_ ++ _)
-      }
+    val allIncludedBroadcastTuples = includedBroadcastTuples.fold(
+      new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
+    )(_ ++ _)
 
     val leftNulls = new GenericMutableRow(left.output.size)
     val rightNulls = new GenericMutableRow(right.output.size)
@@ -127,8 +138,10 @@ case class BroadcastNestedLoopJoin(
       while (i < rel.length) {
         if (!allIncludedBroadcastTuples.contains(i)) {
           (joinType, buildSide) match {
-            case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i))
-            case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls)
+            case (RightOuter | FullOuter, BuildRight) =>
+              buf += resultProjection(new JoinedRow(leftNulls, rel(i)))
+            case (LeftOuter | FullOuter, BuildLeft) =>
+              buf += resultProjection(new JoinedRow(rel(i), rightNulls))
             case _ =>
           }
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index ff85ea3..ae34409 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -44,11 +44,20 @@ trait HashJoin {
 
   override def output: Seq[Attribute] = left.output ++ right.output
 
-  @transient protected lazy val buildSideKeyGenerator: Projection =
-    newProjection(buildKeys, buildPlan.output)
+  protected[this] def supportUnsafe: Boolean = {
+    (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys)
+      && UnsafeProjection.canSupport(self.schema))
+  }
+
+  override def outputsUnsafeRows: Boolean = supportUnsafe
+  override def canProcessUnsafeRows: Boolean = supportUnsafe
 
-  @transient protected lazy val streamSideKeyGenerator: () => MutableProjection =
-    newMutableProjection(streamedKeys, streamedPlan.output)
+  @transient protected lazy val streamSideKeyGenerator: Projection =
+    if (supportUnsafe) {
+      UnsafeProjection.create(streamedKeys, streamedPlan.output)
+    } else {
+      newMutableProjection(streamedKeys, streamedPlan.output)()
+    }
 
   protected def hashJoin(
       streamIter: Iterator[InternalRow],
@@ -61,8 +70,17 @@ trait HashJoin {
 
       // Mutable per row objects.
       private[this] val joinRow = new JoinedRow2
+      private[this] val resultProjection: Projection = {
+        if (supportUnsafe) {
+          UnsafeProjection.create(self.schema)
+        } else {
+          new Projection {
+            override def apply(r: InternalRow): InternalRow = r
+          }
+        }
+      }
 
-      private[this] val joinKeys = streamSideKeyGenerator()
+      private[this] val joinKeys = streamSideKeyGenerator
 
       override final def hasNext: Boolean =
         (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
@@ -74,7 +92,7 @@ trait HashJoin {
           case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
         }
         currentMatchPosition += 1
-        ret
+        resultProjection(ret)
       }
 
       /**
@@ -89,8 +107,9 @@ trait HashJoin {
 
         while (currentHashMatches == null && streamIter.hasNext) {
           currentStreamedRow = streamIter.next()
-          if (!joinKeys(currentStreamedRow).anyNull) {
-            currentHashMatches = hashedRelation.get(joinKeys.currentValue)
+          val key = joinKeys(currentStreamedRow)
+          if (!key.anyNull) {
+            currentHashMatches = hashedRelation.get(key)
           }
         }
 
@@ -103,4 +122,12 @@ trait HashJoin {
       }
     }
   }
+
+  protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
+    if (supportUnsafe) {
+      UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
+    } else {
+      HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index 74a7db7..6bf2f82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
-import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.util.collection.CompactBuffer
 
@@ -38,7 +38,7 @@ trait HashOuterJoin {
   val left: SparkPlan
   val right: SparkPlan
 
-override def outputPartitioning: Partitioning = joinType match {
+  override def outputPartitioning: Partitioning = joinType match {
     case LeftOuter => left.outputPartitioning
     case RightOuter => right.outputPartitioning
     case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
@@ -59,6 +59,49 @@ override def outputPartitioning: Partitioning = joinType match {
     }
   }
 
+  protected[this] lazy val (buildPlan, streamedPlan) = joinType match {
+    case RightOuter => (left, right)
+    case LeftOuter => (right, left)
+    case x =>
+      throw new IllegalArgumentException(
+        s"HashOuterJoin should not take $x as the JoinType")
+  }
+
+  protected[this] lazy val (buildKeys, streamedKeys) = joinType match {
+    case RightOuter => (leftKeys, rightKeys)
+    case LeftOuter => (rightKeys, leftKeys)
+    case x =>
+      throw new IllegalArgumentException(
+        s"HashOuterJoin should not take $x as the JoinType")
+  }
+
+  protected[this] def supportUnsafe: Boolean = {
+    (self.codegenEnabled && joinType != FullOuter
+      && UnsafeProjection.canSupport(buildKeys)
+      && UnsafeProjection.canSupport(self.schema))
+  }
+
+  override def outputsUnsafeRows: Boolean = supportUnsafe
+  override def canProcessUnsafeRows: Boolean = supportUnsafe
+
+  protected[this] def streamedKeyGenerator(): Projection = {
+    if (supportUnsafe) {
+      UnsafeProjection.create(streamedKeys, streamedPlan.output)
+    } else {
+      newProjection(streamedKeys, streamedPlan.output)
+    }
+  }
+
+  @transient private[this] lazy val resultProjection: Projection = {
+    if (supportUnsafe) {
+      UnsafeProjection.create(self.schema)
+    } else {
+      new Projection {
+        override def apply(r: InternalRow): InternalRow = r
+      }
+    }
+  }
+
   @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null)
   @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
 
@@ -76,16 +119,20 @@ override def outputPartitioning: Partitioning = joinType match {
       rightIter: Iterable[InternalRow]): Iterator[InternalRow] = {
     val ret: Iterable[InternalRow] = {
       if (!key.anyNull) {
-        val temp = rightIter.collect {
-          case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy()
+        val temp = if (rightIter != null) {
+          rightIter.collect {
+            case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy()
+          }
+        } else {
+          List.empty
         }
         if (temp.isEmpty) {
-          joinedRow.withRight(rightNullRow).copy :: Nil
+          resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil
         } else {
           temp
         }
       } else {
-        joinedRow.withRight(rightNullRow).copy :: Nil
+        resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil
       }
     }
     ret.iterator
@@ -97,17 +144,21 @@ override def outputPartitioning: Partitioning = joinType match {
       joinedRow: JoinedRow): Iterator[InternalRow] = {
     val ret: Iterable[InternalRow] = {
       if (!key.anyNull) {
-        val temp = leftIter.collect {
-          case l if boundCondition(joinedRow.withLeft(l)) =>
-            joinedRow.copy()
+        val temp = if (leftIter != null) {
+          leftIter.collect {
+            case l if boundCondition(joinedRow.withLeft(l)) =>
+              resultProjection(joinedRow).copy()
+          }
+        } else {
+          List.empty
         }
         if (temp.isEmpty) {
-          joinedRow.withLeft(leftNullRow).copy :: Nil
+          resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil
         } else {
           temp
         }
       } else {
-        joinedRow.withLeft(leftNullRow).copy :: Nil
+        resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil
       }
     }
     ret.iterator
@@ -159,6 +210,7 @@ override def outputPartitioning: Partitioning = joinType match {
     }
   }
 
+  // This is only used by FullOuter
   protected[this] def buildHashTable(
       iter: Iterator[InternalRow],
       keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = {
@@ -178,4 +230,12 @@ override def outputPartitioning: Partitioning = joinType match {
 
     hashTable
   }
+
+  protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
+    if (supportUnsafe) {
+      UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
+    } else {
+      HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
index 1b983bc..7f49264 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
@@ -32,34 +32,45 @@ trait HashSemiJoin {
 
   override def output: Seq[Attribute] = left.output
 
-  @transient protected lazy val rightKeyGenerator: Projection =
-    newProjection(rightKeys, right.output)
+  protected[this] def supportUnsafe: Boolean = {
+    (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys)
+      && UnsafeProjection.canSupport(rightKeys)
+      && UnsafeProjection.canSupport(left.schema))
+  }
+
+  override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows
+  override def canProcessUnsafeRows: Boolean = supportUnsafe
+
+  @transient protected lazy val leftKeyGenerator: Projection =
+    if (supportUnsafe) {
+      UnsafeProjection.create(leftKeys, left.output)
+    } else {
+      newMutableProjection(leftKeys, left.output)()
+    }
 
-  @transient protected lazy val leftKeyGenerator: () => MutableProjection =
-    newMutableProjection(leftKeys, left.output)
+  @transient protected lazy val rightKeyGenerator: Projection =
+    if (supportUnsafe) {
+      UnsafeProjection.create(rightKeys, right.output)
+    } else {
+      newMutableProjection(rightKeys, right.output)()
+    }
 
   @transient private lazy val boundCondition =
     newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
 
-  protected def buildKeyHashSet(
-      buildIter: Iterator[InternalRow],
-      copy: Boolean): java.util.Set[InternalRow] = {
+  protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = {
     val hashSet = new java.util.HashSet[InternalRow]()
     var currentRow: InternalRow = null
 
     // Create a Hash set of buildKeys
+    val rightKey = rightKeyGenerator
     while (buildIter.hasNext) {
       currentRow = buildIter.next()
-      val rowKey = rightKeyGenerator(currentRow)
+      val rowKey = rightKey(currentRow)
       if (!rowKey.anyNull) {
         val keyExists = hashSet.contains(rowKey)
         if (!keyExists) {
-          if (copy) {
-            hashSet.add(rowKey.copy())
-          } else {
-            // rowKey may be not serializable (from codegen)
-            hashSet.add(rowKey)
-          }
+          hashSet.add(rowKey.copy())
         }
       }
     }
@@ -67,25 +78,34 @@ trait HashSemiJoin {
   }
 
   protected def hashSemiJoin(
-      streamIter: Iterator[InternalRow],
-      hashedRelation: HashedRelation): Iterator[InternalRow] = {
-    val joinKeys = leftKeyGenerator()
-    val joinedRow = new JoinedRow
+    streamIter: Iterator[InternalRow],
+    hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = {
+    val joinKeys = leftKeyGenerator
     streamIter.filter(current => {
-      lazy val rowBuffer = hashedRelation.get(joinKeys.currentValue)
-      !joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists {
-        (build: InternalRow) => boundCondition(joinedRow(current, build))
-      }
+      val key = joinKeys(current)
+      !key.anyNull && hashSet.contains(key)
     })
   }
 
+  protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
+    if (supportUnsafe) {
+      UnsafeHashedRelation(buildIter, rightKeys, right)
+    } else {
+      HashedRelation(buildIter, newProjection(rightKeys, right.output))
+    }
+  }
+
   protected def hashSemiJoin(
       streamIter: Iterator[InternalRow],
-      hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = {
-    val joinKeys = leftKeyGenerator()
+      hashedRelation: HashedRelation): Iterator[InternalRow] = {
+    val joinKeys = leftKeyGenerator
     val joinedRow = new JoinedRow
-    streamIter.filter(current => {
-      !joinKeys(current.copy()).anyNull && hashSet.contains(joinKeys.currentValue)
-    })
+    streamIter.filter { current =>
+      val key = joinKeys(current)
+      lazy val rowBuffer = hashedRelation.get(key)
+      !key.anyNull && rowBuffer != null && rowBuffer.exists {
+        (row: InternalRow) => boundCondition(joinedRow(current, row))
+      }
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 6b51f5d..8d5731a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.sql.execution.joins
 
-import java.io.{ObjectInput, ObjectOutput, Externalizable}
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.util.{HashMap => JavaHashMap}
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Projection
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.collection.CompactBuffer
 
 
@@ -98,7 +99,6 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR
   }
 }
 
-
 // TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys.
 
 
@@ -148,3 +148,80 @@ private[joins] object HashedRelation {
     }
   }
 }
+
+
+/**
+ * A HashedRelation for UnsafeRow, which is currently backed by a java.util.HashMap that maps
+ * each key to a sequence of values.
+ *
+ * TODO(davies): use BytesToBytesMap
+ */
+private[joins] final class UnsafeHashedRelation(
+    private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
+  extends HashedRelation with Externalizable {
+
+  def this() = this(null)  // Needed for serialization
+
+  override def get(key: InternalRow): CompactBuffer[InternalRow] = {
+    val unsafeKey = key.asInstanceOf[UnsafeRow]
+    // CompactBuffer[UnsafeRow] -> CompactBuffer[InternalRow] cast works thanks to type erasure
+    hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]]
+  }
+
+  override def writeExternal(out: ObjectOutput): Unit = {
+    writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+  }
+
+  override def readExternal(in: ObjectInput): Unit = {
+    hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+  }
+}
+
+private[joins] object UnsafeHashedRelation {
+
+  def apply(
+      input: Iterator[InternalRow],
+      buildKeys: Seq[Expression],
+      buildPlan: SparkPlan,
+      sizeEstimate: Int = 64): HashedRelation = {
+    val boundedKeys = buildKeys.map(BindReferences.bindReference(_, buildPlan.output))
+    apply(input, boundedKeys, buildPlan.schema, sizeEstimate)
+  }
+
+  // Takes already-bound keys and the row schema; used by the overload above and by tests
+  def apply(
+      input: Iterator[InternalRow],
+      buildKeys: Seq[Expression],
+      rowSchema: StructType,
+      sizeEstimate: Int): HashedRelation = {
+
+    // TODO: Use BytesToBytesMap.
+    val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
+    val toUnsafe = UnsafeProjection.create(rowSchema)
+    val keyGenerator = UnsafeProjection.create(buildKeys)
+
+    // Create a mapping of buildKeys -> rows
+    while (input.hasNext) {
+      val currentRow = input.next()
+      val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) {
+        currentRow.asInstanceOf[UnsafeRow]
+      } else {
+        toUnsafe(currentRow)
+      }
+      val rowKey = keyGenerator(unsafeRow)
+      if (!rowKey.anyNull) {
+        val existingMatchList = hashTable.get(rowKey)
+        val matchList = if (existingMatchList == null) {
+          val newMatchList = new CompactBuffer[UnsafeRow]()
+          hashTable.put(rowKey.copy(), newMatchList)
+          newMatchList
+        } else {
+          existingMatchList
+        }
+        matchList += unsafeRow.copy()
+      }
+    }
+
+    new UnsafeHashedRelation(hashTable)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
index db5be9f..4443455 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
@@ -39,6 +39,9 @@ case class LeftSemiJoinBNL(
 
   override def output: Seq[Attribute] = left.output
 
+  override def outputsUnsafeRows: Boolean = streamed.outputsUnsafeRows
+  override def canProcessUnsafeRows: Boolean = true
+
   /** The Streamed Relation */
   override def left: SparkPlan = streamed
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index 9eaac81..874712a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -43,10 +43,10 @@ case class LeftSemiJoinHash(
   protected override def doExecute(): RDD[InternalRow] = {
     right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
       if (condition.isEmpty) {
-        val hashSet = buildKeyHashSet(buildIter, copy = false)
+        val hashSet = buildKeyHashSet(buildIter)
         hashSemiJoin(streamIter, hashSet)
       } else {
-        val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+        val hashRelation = buildHashRelation(buildIter)
         hashSemiJoin(streamIter, hashRelation)
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 5439e10..948d0cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -45,7 +45,7 @@ case class ShuffledHashJoin(
 
   protected override def doExecute(): RDD[InternalRow] = {
     buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
-      val hashed = HashedRelation(buildIter, buildSideKeyGenerator)
+      val hashed = buildHashRelation(buildIter)
       hashJoin(streamIter, hashed)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
index ab0a6ad..f54f1ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
@@ -50,24 +50,25 @@ case class ShuffledHashOuterJoin(
       // TODO this probably can be replaced by external sort (sort merged join?)
       joinType match {
         case LeftOuter =>
-          val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
-          val keyGenerator = newProjection(leftKeys, left.output)
+          val hashed = buildHashRelation(rightIter)
+          val keyGenerator = streamedKeyGenerator()
           leftIter.flatMap( currentRow => {
             val rowKey = keyGenerator(currentRow)
             joinedRow.withLeft(currentRow)
-            leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST))
+            leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey))
           })
 
         case RightOuter =>
-          val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
-          val keyGenerator = newProjection(rightKeys, right.output)
+          val hashed = buildHashRelation(leftIter)
+          val keyGenerator = streamedKeyGenerator()
           rightIter.flatMap ( currentRow => {
             val rowKey = keyGenerator(currentRow)
             joinedRow.withRight(currentRow)
-            rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
+            rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow)
           })
 
         case FullOuter =>
+          // TODO(davies): use UnsafeRow
           val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
           val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
           (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
index 421d510..29f3beb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
@@ -29,6 +29,9 @@ import org.apache.spark.sql.catalyst.rules.Rule
  */
 @DeveloperApi
 case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode {
+
+  require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe")
+
   override def output: Seq[Attribute] = child.output
   override def outputsUnsafeRows: Boolean = true
   override def canProcessUnsafeRows: Boolean = false
@@ -93,11 +96,19 @@ private[sql] object EnsureRowFormats extends Rule[SparkPlan] {
       }
     case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) =>
       if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) {
-        // If this operator's children produce both unsafe and safe rows, then convert everything
-        // to unsafe rows
-        operator.withNewChildren {
-          operator.children.map {
-            c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
+        // If this operator's children produce both unsafe and safe rows, convert everything to
+        // unsafe rows when every child schema is supported by UnsafeRow; otherwise, to safe rows.
+        if (operator.children.forall(c => UnsafeProjection.canSupport(c.schema))) {
+          operator.withNewChildren {
+            operator.children.map {
+              c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
+            }
+          }
+        } else {
+          operator.withNewChildren {
+            operator.children.map {
+              c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c
+            }
           }
         }
       } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index 3854dc1..d36e263 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -22,7 +22,7 @@ import java.io.ByteArrayOutputStream
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
-import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.memory.MemoryAllocator
 import org.apache.spark.unsafe.types.UTF8String
@@ -31,7 +31,7 @@ class UnsafeRowSuite extends SparkFunSuite {
   test("writeToStream") {
     val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123)
     val arrayBackedUnsafeRow: UnsafeRow =
-      UnsafeProjection.create(Seq(StringType, StringType, IntegerType)).apply(row)
+      UnsafeProjection.create(Array[DataType](StringType, StringType, IntegerType)).apply(row)
     assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]])
     val bytesFromArrayBackedRow: Array[Byte] = {
       val baos = new ByteArrayOutputStream()

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 9d9858b..9dd2220 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Projection
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.types.{StructField, StructType, IntegerType}
 import org.apache.spark.util.collection.CompactBuffer
 
 
@@ -35,13 +37,13 @@ class HashedRelationSuite extends SparkFunSuite {
     val hashed = HashedRelation(data.iterator, keyProjection)
     assert(hashed.isInstanceOf[GeneralHashedRelation])
 
-    assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0)))
-    assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1)))
+    assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0)))
+    assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1)))
     assert(hashed.get(InternalRow(10)) === null)
 
     val data2 = CompactBuffer[InternalRow](data(2))
     data2 += data(2)
-    assert(hashed.get(data(2)) == data2)
+    assert(hashed.get(data(2)) === data2)
   }
 
   test("UniqueKeyHashedRelation") {
@@ -49,15 +51,40 @@ class HashedRelationSuite extends SparkFunSuite {
     val hashed = HashedRelation(data.iterator, keyProjection)
     assert(hashed.isInstanceOf[UniqueKeyHashedRelation])
 
-    assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0)))
-    assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1)))
-    assert(hashed.get(data(2)) == CompactBuffer[InternalRow](data(2)))
+    assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0)))
+    assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1)))
+    assert(hashed.get(data(2)) === CompactBuffer[InternalRow](data(2)))
     assert(hashed.get(InternalRow(10)) === null)
 
     val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation]
-    assert(uniqHashed.getValue(data(0)) == data(0))
-    assert(uniqHashed.getValue(data(1)) == data(1))
-    assert(uniqHashed.getValue(data(2)) == data(2))
-    assert(uniqHashed.getValue(InternalRow(10)) == null)
+    assert(uniqHashed.getValue(data(0)) === data(0))
+    assert(uniqHashed.getValue(data(1)) === data(1))
+    assert(uniqHashed.getValue(data(2)) === data(2))
+    assert(uniqHashed.getValue(InternalRow(10)) === null)
+  }
+
+  test("UnsafeHashedRelation") {
+    val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
+    val buildKey = Seq(BoundReference(0, IntegerType, false))
+    val schema = StructType(StructField("a", IntegerType, true) :: Nil)
+    val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema, 1)
+    assert(hashed.isInstanceOf[UnsafeHashedRelation])
+
+    val toUnsafeKey = UnsafeProjection.create(schema)
+    val unsafeData = data.map(toUnsafeKey(_).copy()).toArray
+    assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
+    assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
+    assert(hashed.get(toUnsafeKey(InternalRow(10))) === null)
+
+    val data2 = CompactBuffer[InternalRow](unsafeData(2).copy())
+    data2 += unsafeData(2).copy()
+    assert(hashed.get(unsafeData(2)) === data2)
+
+    val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed))
+      .asInstanceOf[UnsafeHashedRelation]
+    assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
+    assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
+    assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null)
+    assert(hashed2.get(unsafeData(2)) === data2)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0b7ba59/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
index 85cd024..61f483c 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
@@ -44,12 +44,16 @@ public final class Murmur3_x86_32 {
     return fmix(h1, 4);
   }
 
-  public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes) {
+  public int hashUnsafeWords(Object base, long offset, int lengthInBytes) {
+    return hashUnsafeWords(base, offset, lengthInBytes, seed);
+  }
+
+  public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
     // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
     assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
     int h1 = seed;
-    for (int offset = 0; offset < lengthInBytes; offset += 4) {
-      int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
+    for (int i = 0; i < lengthInBytes; i += 4) {
+      int halfWord = PlatformDependent.UNSAFE.getInt(base, offset + i);
       int k1 = mixK1(halfWord);
       h1 = mixH1(h1, k1);
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org