You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by an...@apache.org on 2015/10/08 20:56:48 UTC

spark git commit: [SPARK-10887] [SQL] Build HashedRelation outside of HashJoinNode.

Repository: spark
Updated Branches:
  refs/heads/master 2a6f614cd -> 82d275f27


[SPARK-10887] [SQL] Build HashedRelation outside of HashJoinNode.

This PR refactors `HashJoinNode` to take an existing `HashedRelation`. So, we can reuse this node for both `ShuffledHashJoin` and `BroadcastHashJoin`.

https://issues.apache.org/jira/browse/SPARK-10887

Author: Yin Huai <yh...@databricks.com>

Closes #8953 from yhuai/SPARK-10887.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/82d275f2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/82d275f2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/82d275f2

Branch: refs/heads/master
Commit: 82d275f27c3e9211ce69c5c8685a0fe90c0be26f
Parents: 2a6f614
Author: Yin Huai <yh...@databricks.com>
Authored: Thu Oct 8 11:56:44 2015 -0700
Committer: Andrew Or <an...@databricks.com>
Committed: Thu Oct 8 11:56:44 2015 -0700

----------------------------------------------------------------------
 .../codegen/GenerateMutableProjection.scala     |  2 +
 .../codegen/GenerateSafeProjection.scala        |  4 +-
 .../execution/local/BinaryHashJoinNode.scala    | 76 +++++++++++++++++
 .../execution/local/BroadcastHashJoinNode.scala | 59 ++++++++++++++
 .../sql/execution/local/HashJoinNode.scala      | 67 +++++++--------
 .../sql/execution/local/HashJoinNodeSuite.scala | 85 +++++++++++++++++---
 .../sql/execution/local/LocalNodeTest.scala     | 20 ++++-
 7 files changed, 262 insertions(+), 51 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/82d275f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index d82d191..e8ee647 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -27,6 +27,8 @@ abstract class BaseMutableProjection extends MutableProjection
 /**
  * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new
  * input [[InternalRow]] for a fixed set of [[Expression Expressions]].
+ * It exposes a `target` method, which is used to set the row that will be updated.
+ * The [[MutableRow]] object created internally is used only when `target` is not called.
  */
 object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/82d275f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index ea09e02..9873630 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -23,8 +23,8 @@ import org.apache.spark.sql.types._
 
 
 /**
- * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new
- * input [[InternalRow]] for a fixed set of [[Expression Expressions]].
+ * Generates byte code that produces a [[MutableRow]] object (not an [[UnsafeRow]]) that can update
+ * itself based on a new input [[InternalRow]] for a fixed set of [[Expression Expressions]].
  */
 object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/82d275f2/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala
new file mode 100644
index 0000000..52dcb9e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala
@@ -0,0 +1,76 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide}
+
+/**
+ * A [[HashJoinNode]] that builds the [[HashedRelation]] according to the value of
+ * `buildSide`. The actual work of this node is defined in [[HashJoinNode]].
+ */
+case class BinaryHashJoinNode(
+    conf: SQLConf,
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    buildSide: BuildSide,
+    left: LocalNode,
+    right: LocalNode)
+  extends BinaryLocalNode(conf) with HashJoinNode {
+
+  protected override val (streamedNode, streamedKeys) = buildSide match {
+    case BuildLeft => (right, rightKeys)
+    case BuildRight => (left, leftKeys)
+  }
+
+  private val (buildNode, buildKeys) = buildSide match {
+    case BuildLeft => (left, leftKeys)
+    case BuildRight => (right, rightKeys)
+  }
+
+  override def output: Seq[Attribute] = left.output ++ right.output
+
+  private def buildSideKeyGenerator: Projection = {
+    // We expect the data types of buildKeys and streamedKeys to be the same.
+    assert(buildKeys.map(_.dataType) == streamedKeys.map(_.dataType))
+    if (isUnsafeMode) {
+      UnsafeProjection.create(buildKeys, buildNode.output)
+    } else {
+      newMutableProjection(buildKeys, buildNode.output)()
+    }
+  }
+
+  protected override def doOpen(): Unit = {
+    buildNode.open()
+    val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator)
+    // We have built the HashedRelation. So, close buildNode.
+    buildNode.close()
+
+    streamedNode.open()
+    // Set the HashedRelation used by the HashJoinNode.
+    withHashedRelation(hashedRelation)
+  }
+
+  override def close(): Unit = {
+    // Please note that we do not need to call the close method of our buildNode because
+    // it has already been called in doOpen.
+    streamedNode.close()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/82d275f2/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala
new file mode 100644
index 0000000..cd1c865
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala
@@ -0,0 +1,59 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation}
+
+/**
+ * A [[HashJoinNode]] for broadcast join. It takes a streamedNode and a broadcast
+ * [[HashedRelation]]. The actual work of this node is defined in [[HashJoinNode]].
+ */
+case class BroadcastHashJoinNode(
+    conf: SQLConf,
+    streamedKeys: Seq[Expression],
+    streamedNode: LocalNode,
+    buildSide: BuildSide,
+    buildOutput: Seq[Attribute],
+    hashedRelation: Broadcast[HashedRelation])
+  extends UnaryLocalNode(conf) with HashJoinNode {
+
+  override val child = streamedNode
+
+  // Because we do not pass in the buildNode, we take the output of buildNode to
+  // create the inputSet properly.
+  override def inputSet: AttributeSet = AttributeSet(child.output ++ buildOutput)
+
+  override def output: Seq[Attribute] = buildSide match {
+    case BuildRight => streamedNode.output ++ buildOutput
+    case BuildLeft => buildOutput ++ streamedNode.output
+  }
+
+  protected override def doOpen(): Unit = {
+    streamedNode.open()
+    // Set the HashedRelation used by the HashJoinNode.
+    withHashedRelation(hashedRelation.value)
+  }
+
+  override def close(): Unit = {
+    streamedNode.close()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/82d275f2/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
index e7b24e3..b1dc719 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
@@ -17,27 +17,23 @@
 
 package org.apache.spark.sql.execution.local
 
-import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
+ * An abstract node for sharing common functionality among different implementations of
+ * inner hash equi-join, notably [[BinaryHashJoinNode]] and [[BroadcastHashJoinNode]].
+ *
  * Much of this code is similar to [[org.apache.spark.sql.execution.joins.HashJoin]].
  */
-case class HashJoinNode(
-    conf: SQLConf,
-    leftKeys: Seq[Expression],
-    rightKeys: Seq[Expression],
-    buildSide: BuildSide,
-    left: LocalNode,
-    right: LocalNode) extends BinaryLocalNode(conf) {
-
-  private[this] lazy val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match {
-    case BuildLeft => (left, leftKeys, right, rightKeys)
-    case BuildRight => (right, rightKeys, left, leftKeys)
-  }
+trait HashJoinNode {
+
+  self: LocalNode =>
+
+  protected def streamedKeys: Seq[Expression]
+  protected def streamedNode: LocalNode
+  protected def buildSide: BuildSide
 
   private[this] var currentStreamedRow: InternalRow = _
   private[this] var currentHashMatches: Seq[InternalRow] = _
@@ -49,23 +45,14 @@ case class HashJoinNode(
   private[this] var hashed: HashedRelation = _
   private[this] var joinKeys: Projection = _
 
-  override def output: Seq[Attribute] = left.output ++ right.output
-
-  private[this] def isUnsafeMode: Boolean = {
-    (codegenEnabled && unsafeEnabled
-      && UnsafeProjection.canSupport(buildKeys)
-      && UnsafeProjection.canSupport(schema))
-  }
-
-  private[this] def buildSideKeyGenerator: Projection = {
-    if (isUnsafeMode) {
-      UnsafeProjection.create(buildKeys, buildNode.output)
-    } else {
-      newMutableProjection(buildKeys, buildNode.output)()
-    }
+  protected def isUnsafeMode: Boolean = {
+    (codegenEnabled &&
+      unsafeEnabled &&
+      UnsafeProjection.canSupport(schema) &&
+      UnsafeProjection.canSupport(streamedKeys))
   }
 
-  private[this] def streamSideKeyGenerator: Projection = {
+  private def streamSideKeyGenerator: Projection = {
     if (isUnsafeMode) {
       UnsafeProjection.create(streamedKeys, streamedNode.output)
     } else {
@@ -73,10 +60,21 @@ case class HashJoinNode(
     }
   }
 
+  /**
+   * Sets the HashedRelation used by this node. This method needs to be called
+   * before the first `next` gets called.
+   */
+  protected def withHashedRelation(hashedRelation: HashedRelation): Unit = {
+    hashed = hashedRelation
+  }
+
+  /**
+   * Custom open implementation to be overridden by subclasses.
+   */
+  protected def doOpen(): Unit
+
   override def open(): Unit = {
-    buildNode.open()
-    hashed = HashedRelation(buildNode, buildSideKeyGenerator)
-    streamedNode.open()
+    doOpen()
     joinRow = new JoinedRow
     resultProjection = {
       if (isUnsafeMode) {
@@ -128,9 +126,4 @@ case class HashJoinNode(
     }
     resultProjection(ret)
   }
-
-  override def close(): Unit = {
-    left.close()
-    right.close()
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/82d275f2/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
index 5c1bdb0..8c2e78b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.sql.execution.local
 
+import org.mockito.Mockito.{mock, when}
+
+import org.apache.spark.broadcast.TorrentBroadcast
 import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
-
+import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, UnsafeProjection, Expression}
+import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide}
 
 class HashJoinNodeSuite extends LocalNodeTest {
 
@@ -34,6 +37,35 @@ class HashJoinNodeSuite extends LocalNodeTest {
   }
 
   /**
+   * Builds a [[HashedRelation]] based on a resolved `buildKeys`
+   * and a resolved `buildNode`.
+   */
+  private def buildHashedRelation(
+      conf: SQLConf,
+      buildKeys: Seq[Expression],
+      buildNode: LocalNode): HashedRelation = {
+
+    val isUnsafeMode =
+      conf.codegenEnabled &&
+        conf.unsafeEnabled &&
+        UnsafeProjection.canSupport(buildKeys)
+
+    val buildSideKeyGenerator =
+      if (isUnsafeMode) {
+        UnsafeProjection.create(buildKeys, buildNode.output)
+      } else {
+        new InterpretedMutableProjection(buildKeys, buildNode.output)
+      }
+
+    buildNode.prepare()
+    buildNode.open()
+    val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator)
+    buildNode.close()
+
+    hashedRelation
+  }
+
+  /**
    * Test inner hash join with varying degrees of matches.
    */
   private def testJoin(
@@ -51,20 +83,51 @@ class HashJoinNodeSuite extends LocalNodeTest {
       val rightInputMap = rightInput.toMap
       val leftNode = new DummyNode(joinNameAttributes, leftInput)
       val rightNode = new DummyNode(joinNicknameAttributes, rightInput)
-      val makeNode = (node1: LocalNode, node2: LocalNode) => {
-        resolveExpressions(new HashJoinNode(
-          conf, Seq('id1), Seq('id2), buildSide, node1, node2))
+      val makeBinaryHashJoinNode = (node1: LocalNode, node2: LocalNode) => {
+        val binaryHashJoinNode =
+          BinaryHashJoinNode(conf, Seq('id1), Seq('id2), buildSide, node1, node2)
+        resolveExpressions(binaryHashJoinNode)
+      }
+      val makeBroadcastJoinNode = (node1: LocalNode, node2: LocalNode) => {
+        val leftKeys = Seq('id1.attr)
+        val rightKeys = Seq('id2.attr)
+        // Figure out the build side and stream side.
+        val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match {
+          case BuildLeft => (node1, leftKeys, node2, rightKeys)
+          case BuildRight => (node2, rightKeys, node1, leftKeys)
+        }
+        // Resolve the expressions of the build side and then create a HashedRelation.
+        val resolvedBuildNode = resolveExpressions(buildNode)
+        val resolvedBuildKeys = resolveExpressions(buildKeys, resolvedBuildNode)
+        val hashedRelation = buildHashedRelation(conf, resolvedBuildKeys, resolvedBuildNode)
+        val broadcastHashedRelation = mock(classOf[TorrentBroadcast[HashedRelation]])
+        when(broadcastHashedRelation.value).thenReturn(hashedRelation)
+
+        val hashJoinNode =
+          BroadcastHashJoinNode(
+            conf,
+            streamedKeys,
+            streamedNode,
+            buildSide,
+            resolvedBuildNode.output,
+            broadcastHashedRelation)
+        resolveExpressions(hashJoinNode)
       }
-      val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode
-      val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
+
       val expectedOutput = leftInput
         .filter { case (k, _) => rightInputMap.contains(k) }
         .map { case (k, v) => (k, v, k, rightInputMap(k)) }
-      val actualOutput = hashJoinNode.collect().map { row =>
-        // (id, name, id, nickname)
-        (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
+
+      Seq(makeBinaryHashJoinNode, makeBroadcastJoinNode).foreach { makeNode =>
+        val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode
+        val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
+
+        val actualOutput = hashJoinNode.collect().map { row =>
+          // (id, name, id, nickname)
+          (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
+        }
+        assert(actualOutput === expectedOutput)
       }
-      assert(actualOutput === expectedOutput)
     }
 
     test(s"$testNamePrefix: empty") {

http://git-wip-us.apache.org/repos/asf/spark/blob/82d275f2/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
index 098050b..615c417 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.local
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference}
 import org.apache.spark.sql.types.{IntegerType, StringType}
 
 
@@ -67,4 +67,22 @@ class LocalNodeTest extends SparkFunSuite {
     }
   }
 
+  /**
+   * Resolve all expressions in `expressions` based on the `output` of `localNode`.
+   * It assumes that all expressions in the `localNode` are resolved.
+   */
+  protected def resolveExpressions(
+      expressions: Seq[Expression],
+      localNode: LocalNode): Seq[Expression] = {
+    require(localNode.expressions.forall(_.resolved))
+    val inputMap = localNode.output.map { a => (a.name, a) }.toMap
+    expressions.map { expression =>
+      expression.transformUp {
+        case UnresolvedAttribute(Seq(u)) =>
+          inputMap.getOrElse(u,
+            sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
+      }
+    }
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org