Posted to commits@spark.apache.org by rx...@apache.org on 2016/02/23 01:34:04 UTC

spark git commit: [SPARK-13422][SQL] Use HashedRelation instead of HashSet in Left Semi Joins

Repository: spark
Updated Branches:
  refs/heads/master 173aa949c -> 206378184


[SPARK-13422][SQL] Use HashedRelation instead of HashSet in Left Semi Joins

Use HashedRelation, which is a more optimized data structure, and reduce code complexity.

Author: Xiu Guo <xg...@gmail.com>

Closes #11291 from xguo27/SPARK-13422.
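
For context, the core idea of the change: a left semi join only needs to know whether a matching build-side row exists, but once an extra join condition is present the matching rows themselves must be evaluated. A HashedRelation (a map from join key to rows) covers both cases, whereas a HashSet of keys covers only the unconditioned one, so keeping both paths doubled the code for no gain. A minimal sketch of the unified approach, using plain Scala collections as stand-ins for Spark's InternalRow and HashedRelation (the names Row, buildHashedRelation and semiJoin are illustrative, not from the patch):

    case class Row(key: Int, value: String)

    // Build side: group rows by join key, analogous to HashedRelation.
    def buildHashedRelation(buildRows: Seq[Row]): Map[Int, Seq[Row]] =
      buildRows.groupBy(_.key)

    // Stream side: keep a row if its key has a match and, when a condition
    // is present, at least one matched row satisfies it -- mirroring the
    // condition.isEmpty short-circuit this patch adds to HashSemiJoin.
    def semiJoin(
        streamRows: Seq[Row],
        relation: Map[Int, Seq[Row]],
        condition: Option[(Row, Row) => Boolean]): Seq[Row] =
      streamRows.filter { left =>
        relation.get(left.key).exists { matched =>
          condition.isEmpty || matched.exists(right => condition.get(left, right))
        }
      }

    // e.g. semiJoin(Seq(Row(1, "x"), Row(3, "y")),
    //               buildHashedRelation(Seq(Row(1, "a"), Row(2, "b"))),
    //               None)                        // keeps only Row(1, "x")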


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/20637818
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/20637818
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/20637818

Branch: refs/heads/master
Commit: 2063781840831469b394313694bfd25cbde2bb1e
Parents: 173aa94
Author: Xiu Guo <xg...@gmail.com>
Authored: Mon Feb 22 16:34:02 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Mon Feb 22 16:34:02 2016 -0800

----------------------------------------------------------------------
 .../joins/BroadcastLeftSemiJoinHash.scala       | 27 +++-------
 .../sql/execution/joins/HashSemiJoin.scala      | 55 +-------------------
 .../sql/execution/joins/LeftSemiJoinHash.scala  | 13 ++---
 3 files changed, 14 insertions(+), 81 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/20637818/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index 1f99fbe..d3bcfad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -26,8 +26,8 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
- * Build the right table's join keys into a HashSet, and iteratively go through the left
- * table, to find the if join keys are in the Hash set.
+ * Build the right table's join keys into a HashedRelation, and iteratively go through the left
+ * table, to find if the join keys are in the HashedRelation.
  */
 case class BroadcastLeftSemiJoinHash(
     leftKeys: Seq[Expression],
@@ -40,29 +40,18 @@ case class BroadcastLeftSemiJoinHash(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    val mode = if (condition.isEmpty) {
-      HashSetBroadcastMode(rightKeys, right.output)
-    } else {
-      HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
-    }
+    val mode = HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
     UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
 
-    if (condition.isEmpty) {
-      val broadcastedRelation = right.executeBroadcast[java.util.Set[InternalRow]]()
-      left.execute().mapPartitionsInternal { streamIter =>
-        hashSemiJoin(streamIter, broadcastedRelation.value, numOutputRows)
-      }
-    } else {
-      val broadcastedRelation = right.executeBroadcast[HashedRelation]()
-      left.execute().mapPartitionsInternal { streamIter =>
-        val hashedRelation = broadcastedRelation.value
-        TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
-        hashSemiJoin(streamIter, hashedRelation, numOutputRows)
-      }
+    val broadcastedRelation = right.executeBroadcast[HashedRelation]()
+    left.execute().mapPartitionsInternal { streamIter =>
+      val hashedRelation = broadcastedRelation.value
+      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
+      hashSemiJoin(streamIter, hashedRelation, numOutputRows)
     }
   }
 }
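
After this change BroadcastLeftSemiJoinHash has a single code path: the build side is hashed once on the driver, broadcast, and every stream partition filters against the shared copy. A hedged standalone analogue using core Spark (the local-mode setup and tuple data are illustrative, not from the patch):

    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(
      new SparkConf().setAppName("broadcast-semi-join-sketch").setMaster("local[*]"))

    val buildSide  = Seq((1, "a"), (2, "b"))
    val streamSide = sc.parallelize(Seq((1, "x"), (3, "y")))

    // Build the key -> rows map once, ship it to every executor.
    val broadcastRelation = sc.broadcast(buildSide.groupBy(_._1))

    val joined = streamSide.mapPartitions { iter =>
      val relation = broadcastRelation.value  // one shared lookup table per executor
      iter.filter { case (key, _) => relation.contains(key) }
    }
    joined.collect()  // Array((1, "x"))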

http://git-wip-us.apache.org/repos/asf/spark/blob/20637818/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
index 1cb6a00..3eed6e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
@@ -43,24 +43,6 @@ trait HashSemiJoin {
   @transient private lazy val boundCondition =
     newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
 
-  protected def buildKeyHashSet(
-      buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = {
-    HashSemiJoin.buildKeyHashSet(rightKeys, right.output, buildIter)
-  }
-
-  protected def hashSemiJoin(
-    streamIter: Iterator[InternalRow],
-    hashSet: java.util.Set[InternalRow],
-    numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
-    val joinKeys = leftKeyGenerator
-    streamIter.filter(current => {
-      val key = joinKeys(current)
-      val r = !key.anyNull && hashSet.contains(key)
-      if (r) numOutputRows += 1
-      r
-    })
-  }
-
   protected def hashSemiJoin(
       streamIter: Iterator[InternalRow],
       hashedRelation: HashedRelation,
@@ -70,44 +52,11 @@ trait HashSemiJoin {
     streamIter.filter { current =>
       val key = joinKeys(current)
       lazy val rowBuffer = hashedRelation.get(key)
-      val r = !key.anyNull && rowBuffer != null && rowBuffer.exists {
+      val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
         (row: InternalRow) => boundCondition(joinedRow(current, row))
-      }
+      })
       if (r) numOutputRows += 1
       r
     }
   }
 }
-
-private[execution] object HashSemiJoin {
-  def buildKeyHashSet(
-    keys: Seq[Expression],
-    attributes: Seq[Attribute],
-    rows: Iterator[InternalRow]): java.util.HashSet[InternalRow] = {
-    val hashSet = new java.util.HashSet[InternalRow]()
-
-    // Create a Hash set of buildKeys
-    val key = UnsafeProjection.create(keys, attributes)
-    while (rows.hasNext) {
-      val currentRow = rows.next()
-      val rowKey = key(currentRow)
-      if (!rowKey.anyNull) {
-        val keyExists = hashSet.contains(rowKey)
-        if (!keyExists) {
-          hashSet.add(rowKey.copy())
-        }
-      }
-    }
-    hashSet
-  }
-}
-
-/** HashSetBroadcastMode requires that the input rows are broadcasted as a set. */
-private[execution] case class HashSetBroadcastMode(
-    keys: Seq[Expression],
-    attributes: Seq[Attribute]) extends BroadcastMode {
-
-  override def transform(rows: Array[InternalRow]): java.util.HashSet[InternalRow] = {
-    HashSemiJoin.buildKeyHashSet(keys, attributes, rows.iterator)
-  }
-}
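
Two details of the surviving hashSemiJoin are worth noting: rowBuffer is a lazy val, so the relation lookup only happens once the null-key check has passed, and the new condition.isEmpty test short-circuits the per-row predicate scan when no extra join condition exists. A distilled, hypothetical form of that filter predicate (generic types stand in for InternalRow and HashedRelation):

    // Returns true when the stream row should be kept. `lookup` may return
    // null for an absent key, mirroring HashedRelation.get.
    def semiJoinPredicate[K, R](
        keyHasNull: Boolean,
        key: => K,
        lookup: K => Seq[R],
        condition: Option[R => Boolean]): Boolean = {
      lazy val rowBuffer = lookup(key)  // not evaluated when keyHasNull is true
      !keyHasNull && rowBuffer != null &&
        (condition.isEmpty || rowBuffer.exists(condition.get))
    }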

http://git-wip-us.apache.org/repos/asf/spark/blob/20637818/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index d8d3045..242ed61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -25,8 +25,8 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
- * Build the right table's join keys into a HashSet, and iteratively go through the left
- * table, to find the if join keys are in the Hash set.
+ * Build the right table's join keys into a HashedRelation, and iteratively go through the left
+ * table, to find if the join keys are in the HashedRelation.
  */
 case class LeftSemiJoinHash(
     leftKeys: Seq[Expression],
@@ -47,13 +47,8 @@ case class LeftSemiJoinHash(
     val numOutputRows = longMetric("numOutputRows")
 
     right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
-      if (condition.isEmpty) {
-        val hashSet = buildKeyHashSet(buildIter)
-        hashSemiJoin(streamIter, hashSet, numOutputRows)
-      } else {
-        val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
-        hashSemiJoin(streamIter, hashRelation, numOutputRows)
-      }
+      val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+      hashSemiJoin(streamIter, hashRelation, numOutputRows)
     }
   }
 }
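
The shuffle-based LeftSemiJoinHash collapses to one path in the same way. Because requiredChildDistribution hash-partitions both sides on the join keys, matching rows land in the same partition, so each build partition can be hashed independently and probed by its co-located stream partition. A hedged analogue with core Spark, reusing the SparkContext sc from the sketch above (the partition count and data are illustrative):

    import org.apache.spark.HashPartitioner

    val build  = sc.parallelize(Seq((1, "a"), (2, "b"))).partitionBy(new HashPartitioner(4))
    val stream = sc.parallelize(Seq((1, "x"), (3, "y"))).partitionBy(new HashPartitioner(4))

    val result = build.zipPartitions(stream) { (buildIter, streamIter) =>
      // Per-partition stand-in for HashedRelation(buildIter, rightKeyGenerator).
      val relation = buildIter.toSeq.groupBy(_._1)
      streamIter.filter { case (key, _) => relation.contains(key) }
    }
    result.collect()  // Array((1, "x"))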

