You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2016/03/21 19:21:45 UTC

spark git commit: [SPARK-14007] [SQL] Manage the memory used by hash map in shuffled hash join

Repository: spark
Updated Branches:
  refs/heads/master 5d8de16e7 -> 9b4e15ba1


[SPARK-14007] [SQL] Manage the memory used by hash map in shuffled hash join

## What changes were proposed in this pull request?

This PR tries to acquire memory for the hash map in shuffled hash join, and fails the task if there is not enough memory (otherwise it could OOM the executor).

It also removes the unused HashedRelation implementations (GeneralHashedRelation and UniqueKeyHashedRelation).

## How was this patch tested?

Existing unit tests. Manual tests with TPCDS Q78.

Author: Davies Liu <da...@databricks.com>

Closes #11826 from davies/cleanup_hash2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9b4e15ba
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9b4e15ba
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9b4e15ba

Branch: refs/heads/master
Commit: 9b4e15ba13f62cff302d978093633fc3181a8475
Parents: 5d8de16
Author: Davies Liu <da...@databricks.com>
Authored: Mon Mar 21 11:21:39 2016 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Mon Mar 21 11:21:39 2016 -0700

----------------------------------------------------------------------
 .../apache/spark/memory/TaskMemoryManager.java  |  2 +-
 .../spark/sql/execution/SparkStrategies.scala   |  6 +-
 .../sql/execution/joins/HashedRelation.scala    | 98 ++------------------
 .../sql/execution/joins/ShuffledHashJoin.scala  | 43 ++++++++-
 .../execution/joins/HashedRelationSuite.scala   | 36 -------
 5 files changed, 52 insertions(+), 133 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9b4e15ba/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index 18612dd..9044bb4 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -268,8 +268,8 @@ public class TaskMemoryManager {
       logger.warn("Failed to allocate a page ({} bytes), try again.", acquired);
       // there is no enough memory actually, it means the actual free memory is smaller than
       // MemoryManager thought, we should keep the acquired memory.
-      acquiredButNotUsed += acquired;
       synchronized (this) {
+        acquiredButNotUsed += acquired;
         allocatedPages.clear(pageNumber);
       }
       // this could trigger spilling to free some pages.

http://git-wip-us.apache.org/repos/asf/spark/blob/9b4e15ba/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index de4b4b7..7841ff0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -100,13 +100,17 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    *     [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side
    *     of the join will be broadcasted and the other side will be streamed, with no shuffling
    *     performed. If both sides of the join are eligible to be broadcasted then the
-   * - Shuffle hash join: if single partition is small enough to build a hash table.
+   * - Shuffle hash join: if the average size of a single partition is small enough to build a hash
+   *     table.
    * - Sort merge: if the matching join keys are sortable.
    */
   object EquiJoinSelection extends Strategy with PredicateHelper {
 
     /**
      * Matches a plan whose single partition should be small enough to build a hash table.
+     *
+     * Note: this assumes that the number of partitions is fixed; additional work is required if
+     * it's dynamic.
      */
     def canBuildHashMap(plan: LogicalPlan): Boolean = {
       plan.statistics.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions

http://git-wip-us.apache.org/repos/asf/spark/blob/9b4e15ba/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 0b0f59c..8cc3528 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -109,51 +109,6 @@ private[execution] trait UniqueHashedRelation extends HashedRelation {
   }
 }
 
-/**
- * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values.
- */
-private[joins] class GeneralHashedRelation(
-    private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]])
-  extends HashedRelation with Externalizable {
-
-  // Needed for serialization (it is public to make Java serialization work)
-  def this() = this(null)
-
-  override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key)
-
-  override def writeExternal(out: ObjectOutput): Unit = {
-    writeBytes(out, SparkSqlSerializer.serialize(hashTable))
-  }
-
-  override def readExternal(in: ObjectInput): Unit = {
-    hashTable = SparkSqlSerializer.deserialize(readBytes(in))
-  }
-}
-
-
-/**
- * A specialized [[HashedRelation]] that maps key into a single value. This implementation
- * assumes the key is unique.
- */
-private[joins] class UniqueKeyHashedRelation(
-  private var hashTable: JavaHashMap[InternalRow, InternalRow])
-  extends UniqueHashedRelation with Externalizable {
-
-  // Needed for serialization (it is public to make Java serialization work)
-  def this() = this(null)
-
-  override def getValue(key: InternalRow): InternalRow = hashTable.get(key)
-
-  override def writeExternal(out: ObjectOutput): Unit = {
-    writeBytes(out, SparkSqlSerializer.serialize(hashTable))
-  }
-
-  override def readExternal(in: ObjectInput): Unit = {
-    hashTable = SparkSqlSerializer.deserialize(readBytes(in))
-  }
-}
-
-
 private[execution] object HashedRelation {
 
   /**
@@ -162,51 +117,16 @@ private[execution] object HashedRelation {
    * Note: The caller should make sure that these InternalRow are different objects.
    */
   def apply(
+      canJoinKeyFitWithinLong: Boolean,
       input: Iterator[InternalRow],
       keyGenerator: Projection,
       sizeEstimate: Int = 64): HashedRelation = {
 
-    if (keyGenerator.isInstanceOf[UnsafeProjection]) {
-      return UnsafeHashedRelation(
-        input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
-    }
-
-    // TODO: Use Spark's HashMap implementation.
-    val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate)
-    var currentRow: InternalRow = null
-
-    // Whether the join key is unique. If the key is unique, we can convert the underlying
-    // hash map into one specialized for this.
-    var keyIsUnique = true
-
-    // Create a mapping of buildKeys -> rows
-    while (input.hasNext) {
-      currentRow = input.next()
-      val rowKey = keyGenerator(currentRow)
-      if (!rowKey.anyNull) {
-        val existingMatchList = hashTable.get(rowKey)
-        val matchList = if (existingMatchList == null) {
-          val newMatchList = new CompactBuffer[InternalRow]()
-          hashTable.put(rowKey.copy(), newMatchList)
-          newMatchList
-        } else {
-          keyIsUnique = false
-          existingMatchList
-        }
-        matchList += currentRow
-      }
-    }
-
-    if (keyIsUnique) {
-      val uniqHashTable = new JavaHashMap[InternalRow, InternalRow](hashTable.size)
-      val iter = hashTable.entrySet().iterator()
-      while (iter.hasNext) {
-        val entry = iter.next()
-        uniqHashTable.put(entry.getKey, entry.getValue()(0))
-      }
-      new UniqueKeyHashedRelation(uniqHashTable)
+    if (canJoinKeyFitWithinLong) {
+      LongHashedRelation(input, keyGenerator, sizeEstimate)
     } else {
-      new GeneralHashedRelation(hashTable)
+      UnsafeHashedRelation(
+        input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
     }
   }
 }
@@ -428,6 +348,7 @@ private[joins] object UnsafeHashedRelation {
       sizeEstimate: Int): HashedRelation = {
 
     // Use a Java hash table here because unsafe maps expect fixed size records
+    // TODO: Use BytesToBytesMap for memory efficiency
     val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
 
     // Create a mapping of buildKeys -> rows
@@ -683,11 +604,7 @@ private[execution] case class HashedRelationBroadcastMode(
 
   override def transform(rows: Array[InternalRow]): HashedRelation = {
     val generator = UnsafeProjection.create(keys, attributes)
-    if (canJoinKeyFitWithinLong) {
-      LongHashedRelation(rows.iterator, generator, rows.length)
-    } else {
-      HashedRelation(rows.iterator, generator, rows.length)
-    }
+    HashedRelation(canJoinKeyFitWithinLong, rows.iterator, generator, rows.length)
   }
 
   private lazy val canonicalizedKeys: Seq[Expression] = {
@@ -703,4 +620,3 @@ private[execution] case class HashedRelationBroadcastMode(
     case _ => false
   }
 }
-

http://git-wip-us.apache.org/repos/asf/spark/blob/9b4e15ba/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 1e8879a..5c4f1ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -17,17 +17,18 @@
 
 package org.apache.spark.sql.execution.joins
 
+import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.memory.MemoryMode
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow}
+import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
- * Performs an inner hash join of two child relations by first shuffling the data using the join
- * keys.
+ * Performs a hash join of two child relations by first shuffling the data using the join keys.
  */
 case class ShuffledHashJoin(
     leftKeys: Seq[Expression],
@@ -55,11 +56,45 @@ case class ShuffledHashJoin(
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
+  private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = {
+    // Try to acquire some memory for the hash table; this could trigger other operators to free
+    // some memory. The memory acquired here will mostly be used until the end of the task.
+    val context = TaskContext.get()
+    val memoryManager = context.taskMemoryManager()
+    var acquired = 0L
+    var used = 0L
+    context.addTaskCompletionListener((t: TaskContext) =>
+      memoryManager.releaseExecutionMemory(acquired, MemoryMode.ON_HEAP, null)
+    )
+
+    val copiedIter = iter.map { row =>
+      // It's hard to know exactly how much memory will be used, so we make a rough guess here.
+      // TODO: use BytesToBytesMap instead of HashMap for memory efficiency
+      // Each pair in HashMap will have two UnsafeRows, one CompactBuffer, maybe 10+ pointers
+      val needed = 150 + row.getSizeInBytes
+      if (needed > acquired - used) {
+        val got = memoryManager.acquireExecutionMemory(
+          Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null)
+        if (got < needed) {
+          throw new SparkException("Can't acquire enough memory to build hash map in shuffled" +
+            "hash join, please use sort merge join by setting " +
+            "spark.sql.join.preferSortMergeJoin=true")
+        }
+        acquired += got
+      }
+      used += needed
+      // HashedRelation requires that the UnsafeRow should be separate objects.
+      row.copy()
+    }
+
+    HashedRelation(canJoinKeyFitWithinLong, copiedIter, buildSideKeyGenerator)
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
 
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
-      val hashed = HashedRelation(buildIter.map(_.copy()), buildSideKeyGenerator)
+      val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]])
       val joinedRow = new JoinedRow
       joinType match {
         case Inner =>

http://git-wip-us.apache.org/repos/asf/spark/blob/9b4e15ba/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 04dd809..dd20855 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -29,42 +29,6 @@ import org.apache.spark.util.collection.CompactBuffer
 
 class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
 
-  // Key is simply the record itself
-  private val keyProjection = new Projection {
-    override def apply(row: InternalRow): InternalRow = row
-  }
-
-  test("GeneralHashedRelation") {
-    val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
-    val hashed = HashedRelation(data.iterator, keyProjection)
-    assert(hashed.isInstanceOf[GeneralHashedRelation])
-
-    assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0)))
-    assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1)))
-    assert(hashed.get(InternalRow(10)) === null)
-
-    val data2 = CompactBuffer[InternalRow](data(2))
-    data2 += data(2)
-    assert(hashed.get(data(2)) === data2)
-  }
-
-  test("UniqueKeyHashedRelation") {
-    val data = Array(InternalRow(0), InternalRow(1), InternalRow(2))
-    val hashed = HashedRelation(data.iterator, keyProjection)
-    assert(hashed.isInstanceOf[UniqueKeyHashedRelation])
-
-    assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0)))
-    assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1)))
-    assert(hashed.get(data(2)) === CompactBuffer[InternalRow](data(2)))
-    assert(hashed.get(InternalRow(10)) === null)
-
-    val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation]
-    assert(uniqHashed.getValue(data(0)) === data(0))
-    assert(uniqHashed.getValue(data(1)) === data(1))
-    assert(uniqHashed.getValue(data(2)) === data(2))
-    assert(uniqHashed.getValue(InternalRow(10)) === null)
-  }
-
   test("UnsafeHashedRelation") {
     val schema = StructType(StructField("a", IntegerType, true) :: Nil)
     val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org