You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2016/03/28 22:07:47 UTC

spark git commit: [SPARK-14052] [SQL] build a BytesToBytesMap directly in HashedRelation

Repository: spark
Updated Branches:
  refs/heads/master 600c0b69c -> d7b58f146


[SPARK-14052] [SQL] build a BytesToBytesMap directly in HashedRelation

## What changes were proposed in this pull request?

Currently, for keys that cannot fit within a long, we build a hash map for UnsafeHashedRelation; it is converted to a BytesToBytesMap after serialization and deserialization. We should build a BytesToBytesMap directly for better memory efficiency.

In order to do that, BytesToBytesMap must support multiple (K, V) pairs with the same K. Location.putNewKey() is renamed to Location.append(), which can append multiple values for the same key (same Location). `Location.nextValue()` is added to find the next value for the same key.

## How was this patch tested?

Existing tests. Added benchmark for broadcast hash join with duplicated keys.

Author: Davies Liu <da...@databricks.com>

Closes #11870 from davies/map2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d7b58f14
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d7b58f14
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d7b58f14

Branch: refs/heads/master
Commit: d7b58f1461f71ee3c028360eef0ffedd17d6a076
Parents: 600c0b6
Author: Davies Liu <da...@databricks.com>
Authored: Mon Mar 28 13:07:32 2016 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Mon Mar 28 13:07:32 2016 -0700

----------------------------------------------------------------------
 .../spark/unsafe/map/BytesToBytesMap.java       | 113 +++++---
 .../map/AbstractBytesToBytesMapSuite.java       |  64 ++++-
 .../UnsafeFixedWidthAggregationMap.java         |   2 +-
 .../sql/execution/joins/HashedRelation.scala    | 281 +++++++++----------
 .../sql/execution/joins/ShuffledHashJoin.scala  |  18 +-
 .../execution/BenchmarkWholeStageCodegen.scala  |  26 +-
 .../execution/joins/HashedRelationSuite.scala   |  15 +-
 7 files changed, 299 insertions(+), 220 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d7b58f14/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 9aacb08..32958be 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -56,9 +56,10 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter;
  *   Bytes 4 to 8: len(k)
  *   Bytes 8 to 8 + len(k): key data
  *   Bytes 8 + len(k) to 8 + len(k) + len(v): value data
+ *   Bytes 8 + len(k) + len(v) to 8 + len(k) + len(v) + 8: pointer to next pair
  *
  * This means that the first four bytes store the entire record (key + value) length. This format
- * is consistent with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter},
+ * is compatible with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter},
  * so we can pass records from this map directly into the sorter to sort records in place.
  */
 public final class BytesToBytesMap extends MemoryConsumer {
@@ -132,7 +133,12 @@ public final class BytesToBytesMap extends MemoryConsumer {
   /**
    * Number of keys defined in the map.
    */
-  private int numElements;
+  private int numKeys;
+
+  /**
+   * Number of values defined in the map. A key could have multiple values.
+   */
+  private int numValues;
 
   /**
    * The map will be expanded once the number of keys exceeds this threshold.
@@ -223,7 +229,12 @@ public final class BytesToBytesMap extends MemoryConsumer {
   /**
    * Returns the number of keys defined in the map.
    */
-  public int numElements() { return numElements; }
+  public int numKeys() { return numKeys; }
+
+  /**
+   * Returns the number of values defined in the map. A key could have multiple values.
+   */
+  public int numValues() { return numValues; }
 
   public final class MapIterator implements Iterator<Location> {
 
@@ -311,7 +322,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
       if (currentPage != null) {
         int totalLength = Platform.getInt(pageBaseObject, offsetInPage);
         loc.with(currentPage, offsetInPage);
-        offsetInPage += 4 + totalLength;
+        // [total size] [key size] [key] [value] [pointer to next]
+        offsetInPage += 4 + totalLength + 8;
         recordsInPage --;
         return loc;
       } else {
@@ -361,7 +373,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
           while (numRecords > 0) {
             int length = Platform.getInt(base, offset);
             writer.write(base, offset + 4, length, 0);
-            offset += 4 + length;
+            offset += 4 + length + 8;
             numRecords--;
           }
           writer.close();
@@ -395,7 +407,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
    * `lookup()`, the behavior of the returned iterator is undefined.
    */
   public MapIterator iterator() {
-    return new MapIterator(numElements, loc, false);
+    return new MapIterator(numValues, loc, false);
   }
 
   /**
@@ -409,7 +421,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
    * `lookup()`, the behavior of the returned iterator is undefined.
    */
   public MapIterator destructiveIterator() {
-    return new MapIterator(numElements, loc, true);
+    return new MapIterator(numValues, loc, true);
   }
 
   /**
@@ -560,6 +572,20 @@ public final class BytesToBytesMap extends MemoryConsumer {
     }
 
     /**
+     * Find the next pair that has the same key as current one.
+     */
+    public boolean nextValue() {
+      assert isDefined;
+      long nextAddr = Platform.getLong(baseObject, valueOffset + valueLength);
+      if (nextAddr == 0) {
+        return false;
+      } else {
+        updateAddressesAndSizes(nextAddr);
+        return true;
+      }
+    }
+
+    /**
      * Returns the memory page that contains the current record.
      * This is only valid if this is returned by {@link BytesToBytesMap#iterator()}.
      */
@@ -625,10 +651,9 @@ public final class BytesToBytesMap extends MemoryConsumer {
     }
 
     /**
-     * Store a new key and value. This method may only be called once for a given key; if you want
-     * to update the value associated with a key, then you can directly manipulate the bytes stored
-     * at the value address. The return value indicates whether the put succeeded or whether it
-     * failed because additional memory could not be acquired.
+     * Append a new value for the key. This method could be called multiple times for a given key.
+     * The return value indicates whether the put succeeded or whether it failed because additional
+     * memory could not be acquired.
      * <p>
      * It is only valid to call this method immediately after calling `lookup()` using the same key.
      * </p>
@@ -637,7 +662,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
      * </p>
      * <p>
      * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length`
-     * will return information on the data stored by this `putNewKey` call.
+     * will return information on the data stored by this `append` call.
      * </p>
      * <p>
      * As an example usage, here's the proper way to store a new key:
@@ -645,7 +670,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
      * <pre>
      *   Location loc = map.lookup(keyBase, keyOffset, keyLength);
      *   if (!loc.isDefined()) {
-     *     if (!loc.putNewKey(keyBase, keyOffset, keyLength, ...)) {
+     *     if (!loc.append(keyBase, keyOffset, keyLength, ...)) {
      *       // handle failure to grow map (by spilling, for example)
      *     }
      *   }
@@ -657,26 +682,23 @@ public final class BytesToBytesMap extends MemoryConsumer {
      * @return true if the put() was successful and false if the put() failed because memory could
      *         not be acquired.
      */
-    public boolean putNewKey(Object keyBase, long keyOffset, int keyLength,
-        Object valueBase, long valueOffset, int valueLength) {
-      assert (!isDefined) : "Can only set value once for a key";
-      assert (keyLength % 8 == 0);
-      assert (valueLength % 8 == 0);
-      assert(longArray != null);
-
+    public boolean append(Object kbase, long koff, int klen, Object vbase, long voff, int vlen) {
+      assert (klen % 8 == 0);
+      assert (vlen % 8 == 0);
+      assert (longArray != null);
 
-      if (numElements == MAX_CAPACITY
+      if (numKeys == MAX_CAPACITY
         // The map could be reused from last spill (because of no enough memory to grow),
         // then we don't try to grow again if hit the `growthThreshold`.
-        || !canGrowArray && numElements > growthThreshold) {
+        || !canGrowArray && numKeys > growthThreshold) {
         return false;
       }
 
       // Here, we'll copy the data into our data pages. Because we only store a relative offset from
       // the key address instead of storing the absolute address of the value, the key and value
       // must be stored in the same memory page.
-      // (8 byte key length) (key) (value)
-      final long recordLength = 8 + keyLength + valueLength;
+      // (8 byte key length) (key) (value) (8 byte pointer to next value)
+      final long recordLength = 8 + klen + vlen + 8;
       if (currentPage == null || currentPage.size() - pageCursor < recordLength) {
         if (!acquireNewPage(recordLength + 4L)) {
           return false;
@@ -687,30 +709,40 @@ public final class BytesToBytesMap extends MemoryConsumer {
       final Object base = currentPage.getBaseObject();
       long offset = currentPage.getBaseOffset() + pageCursor;
       final long recordOffset = offset;
-      Platform.putInt(base, offset, keyLength + valueLength + 4);
-      Platform.putInt(base, offset + 4, keyLength);
+      Platform.putInt(base, offset, klen + vlen + 4);
+      Platform.putInt(base, offset + 4, klen);
       offset += 8;
-      Platform.copyMemory(keyBase, keyOffset, base, offset, keyLength);
-      offset += keyLength;
-      Platform.copyMemory(valueBase, valueOffset, base, offset, valueLength);
+      Platform.copyMemory(kbase, koff, base, offset, klen);
+      offset += klen;
+      Platform.copyMemory(vbase, voff, base, offset, vlen);
+      offset += vlen;
+      Platform.putLong(base, offset, 0);
 
       // --- Update bookkeeping data structures ----------------------------------------------------
       offset = currentPage.getBaseOffset();
       Platform.putInt(base, offset, Platform.getInt(base, offset) + 1);
       pageCursor += recordLength;
-      numElements++;
       final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
         currentPage, recordOffset);
-      longArray.set(pos * 2, storedKeyAddress);
-      longArray.set(pos * 2 + 1, keyHashcode);
-      updateAddressesAndSizes(storedKeyAddress);
-      isDefined = true;
+      numValues++;
+      if (isDefined) {
+        // put this pair at the end of chain
+        while (nextValue()) { /* do nothing */ }
+        Platform.putLong(baseObject, valueOffset + valueLength, storedKeyAddress);
+        nextValue();  // point to new added value
+      } else {
+        numKeys++;
+        longArray.set(pos * 2, storedKeyAddress);
+        longArray.set(pos * 2 + 1, keyHashcode);
+        updateAddressesAndSizes(storedKeyAddress);
+        isDefined = true;
 
-      if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) {
-        try {
-          growAndRehash();
-        } catch (OutOfMemoryError oom) {
-          canGrowArray = false;
+        if (numKeys > growthThreshold && longArray.size() < MAX_CAPACITY) {
+          try {
+            growAndRehash();
+          } catch (OutOfMemoryError oom) {
+            canGrowArray = false;
+          }
         }
       }
       return true;
@@ -866,7 +898,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
    * Reset this map to initialized state.
    */
   public void reset() {
-    numElements = 0;
+    numKeys = 0;
+    numValues = 0;
     longArray.zeroOut();
 
     while (dataPages.size() > 0) {

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b58f14/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 449fb45..84b82f5 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -182,7 +182,7 @@ public abstract class AbstractBytesToBytesMapSuite {
   public void emptyMap() {
     BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES);
     try {
-      Assert.assertEquals(0, map.numElements());
+      Assert.assertEquals(0, map.numKeys());
       final int keyLengthInWords = 10;
       final int keyLengthInBytes = keyLengthInWords * 8;
       final byte[] key = getRandomByteArray(keyLengthInWords);
@@ -204,7 +204,7 @@ public abstract class AbstractBytesToBytesMapSuite {
       final BytesToBytesMap.Location loc =
         map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes);
       Assert.assertFalse(loc.isDefined());
-      Assert.assertTrue(loc.putNewKey(
+      Assert.assertTrue(loc.append(
         keyData,
         Platform.BYTE_ARRAY_OFFSET,
         recordLengthBytes,
@@ -232,7 +232,7 @@ public abstract class AbstractBytesToBytesMapSuite {
         getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes));
 
       try {
-        Assert.assertTrue(loc.putNewKey(
+        Assert.assertTrue(loc.append(
           keyData,
           Platform.BYTE_ARRAY_OFFSET,
           recordLengthBytes,
@@ -260,7 +260,7 @@ public abstract class AbstractBytesToBytesMapSuite {
         Assert.assertFalse(loc.isDefined());
         // Ensure that we store some zero-length keys
         if (i % 5 == 0) {
-          Assert.assertTrue(loc.putNewKey(
+          Assert.assertTrue(loc.append(
             null,
             Platform.LONG_ARRAY_OFFSET,
             0,
@@ -269,7 +269,7 @@ public abstract class AbstractBytesToBytesMapSuite {
             8
           ));
         } else {
-          Assert.assertTrue(loc.putNewKey(
+          Assert.assertTrue(loc.append(
             value,
             Platform.LONG_ARRAY_OFFSET,
             8,
@@ -349,7 +349,7 @@ public abstract class AbstractBytesToBytesMapSuite {
           KEY_LENGTH
         );
         Assert.assertFalse(loc.isDefined());
-        Assert.assertTrue(loc.putNewKey(
+        Assert.assertTrue(loc.append(
           key,
           Platform.LONG_ARRAY_OFFSET,
           KEY_LENGTH,
@@ -417,7 +417,7 @@ public abstract class AbstractBytesToBytesMapSuite {
             key.length
           );
           Assert.assertFalse(loc.isDefined());
-          Assert.assertTrue(loc.putNewKey(
+          Assert.assertTrue(loc.append(
             key,
             Platform.BYTE_ARRAY_OFFSET,
             key.length,
@@ -471,7 +471,7 @@ public abstract class AbstractBytesToBytesMapSuite {
             key.length
           );
           Assert.assertFalse(loc.isDefined());
-          Assert.assertTrue(loc.putNewKey(
+          Assert.assertTrue(loc.append(
             key,
             Platform.BYTE_ARRAY_OFFSET,
             key.length,
@@ -514,7 +514,7 @@ public abstract class AbstractBytesToBytesMapSuite {
       final BytesToBytesMap.Location loc =
         map.lookup(emptyArray, Platform.LONG_ARRAY_OFFSET, 0);
       Assert.assertFalse(loc.isDefined());
-      Assert.assertFalse(loc.putNewKey(
+      Assert.assertFalse(loc.append(
         emptyArray, Platform.LONG_ARRAY_OFFSET, 0, emptyArray, Platform.LONG_ARRAY_OFFSET, 0));
     } finally {
       map.free();
@@ -535,7 +535,7 @@ public abstract class AbstractBytesToBytesMapSuite {
         final long[] arr = new long[]{i};
         final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
         success =
-          loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
+          loc.append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
         if (!success) {
           break;
         }
@@ -556,7 +556,7 @@ public abstract class AbstractBytesToBytesMapSuite {
       for (i = 0; i < 1024; i++) {
         final long[] arr = new long[]{i};
         final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
-        loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
+        loc.append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
       }
       BytesToBytesMap.MapIterator iter = map.iterator();
       for (i = 0; i < 100; i++) {
@@ -587,6 +587,44 @@ public abstract class AbstractBytesToBytesMapSuite {
   }
 
   @Test
+  public void multipleValuesForSameKey() {
+    BytesToBytesMap map =
+      new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.75, 1024, false);
+    try {
+      int i;
+      for (i = 0; i < 1024; i++) {
+        final long[] arr = new long[]{i};
+        map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8)
+          .append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
+      }
+      assert map.numKeys() == 1024;
+      assert map.numValues() == 1024;
+      for (i = 0; i < 1024; i++) {
+        final long[] arr = new long[]{i};
+        map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8)
+          .append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
+      }
+      assert map.numKeys() == 1024;
+      assert map.numValues() == 2048;
+      for (i = 0; i < 1024; i++) {
+        final long[] arr = new long[]{i};
+        final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
+        assert loc.isDefined();
+        assert loc.nextValue();
+        assert !loc.nextValue();
+      }
+      BytesToBytesMap.MapIterator iter = map.iterator();
+      for (i = 0; i < 2048; i++) {
+        assert iter.hasNext();
+        final BytesToBytesMap.Location loc = iter.next();
+        assert loc.isDefined();
+      }
+    } finally {
+      map.free();
+    }
+  }
+
+  @Test
   public void initialCapacityBoundsChecking() {
     try {
       new BytesToBytesMap(taskMemoryManager, 0, PAGE_SIZE_BYTES);
@@ -608,7 +646,7 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void testPeakMemoryUsed() {
-    final long recordLengthBytes = 24;
+    final long recordLengthBytes = 32;
     final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker
     final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes;
     final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1024, pageSizeBytes);
@@ -622,7 +660,7 @@ public abstract class AbstractBytesToBytesMapSuite {
     try {
       for (long i = 0; i < numRecordsPerPage * 10; i++) {
         final long[] value = new long[]{i};
-        map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8).putNewKey(
+        map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8).append(
           value,
           Platform.LONG_ARRAY_OFFSET,
           8,

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b58f14/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 8882903..1f1b538 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -134,7 +134,7 @@ public final class UnsafeFixedWidthAggregationMap {
     if (!loc.isDefined()) {
       // This is the first time that we've seen this grouping key, so we'll insert a copy of the
       // empty aggregation buffer into the map:
-      boolean putSucceeded = loc.putNewKey(
+      boolean putSucceeded = loc.append(
         key.getBaseObject(),
         key.getBaseOffset(),
         key.getSizeInBytes(),

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b58f14/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 8cc3528..dc4793e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -18,18 +18,18 @@
 package org.apache.spark.sql.execution.joins
 
 import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
-import java.nio.ByteOrder
 import java.util.{HashMap => JavaHashMap}
 
-import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext}
 import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
 import org.apache.spark.sql.execution.SparkSqlSerializer
 import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.hash.Murmur3_x86_32
 import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils}
+import org.apache.spark.util.{KnownSizeEstimation, Utils}
 import org.apache.spark.util.collection.CompactBuffer
 
 /**
@@ -54,6 +54,11 @@ private[execution] sealed trait HashedRelation {
     */
   def getMemorySize: Long = 1L  // to make the test happy
 
+  /**
+   * Release any used resources.
+   */
+  def close(): Unit = {}
+
   // This is a helper method to implement Externalizable, and is used by
   // GeneralHashedRelation and UniqueKeyHashedRelation
   protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
@@ -132,163 +137,83 @@ private[execution] object HashedRelation {
 }
 
 /**
- * A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key
- * into a sequence of values.
- *
- * When it's created, it uses HashMap. After it's serialized and deserialized, it switch to use
- * BytesToBytesMap for better memory performance (multiple values for the same are stored as a
- * continuous byte array.
+ * A HashedRelation for UnsafeRow, which is backed BytesToBytesMap.
  *
  * It's serialized in the following format:
  *  [number of keys]
- *  [size of key] [size of all values in bytes] [key bytes] [bytes for all values]
- *  ...
- *
- * All the values are serialized as following:
- *   [number of fields] [number of bytes] [underlying bytes of UnsafeRow]
- *   ...
+ *  [size of key] [size of value] [key bytes] [bytes for value]
  */
-private[joins] final class UnsafeHashedRelation(
-    private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
-  extends HashedRelation
-  with KnownSizeEstimation
-  with Externalizable {
-
-  private[joins] def this() = this(null)  // Needed for serialization
+private[joins] class UnsafeHashedRelation(
+    private var numFields: Int,
+    private var binaryMap: BytesToBytesMap)
+  extends HashedRelation with KnownSizeEstimation with Externalizable {
 
-  // Use BytesToBytesMap in executor for better performance (it's created when deserialization)
-  // This is used in broadcast joins and distributed mode only
-  @transient private[this] var binaryMap: BytesToBytesMap = _
+  private[joins] def this() = this(0, null)  // Needed for serialization
 
-  /**
-   * Return the size of the unsafe map on the executors.
-   *
-   * For broadcast joins, this hashed relation is bigger on the driver because it is
-   * represented as a Java hash map there. While serializing the map to the executors,
-   * however, we rehash the contents in a binary map to reduce the memory footprint on
-   * the executors.
-   *
-   * For non-broadcast joins or in local mode, return 0.
-   */
   override def getMemorySize: Long = {
-    if (binaryMap != null) {
-      binaryMap.getTotalMemoryConsumption
-    } else {
-      0
-    }
+    binaryMap.getTotalMemoryConsumption
   }
 
   override def estimatedSize: Long = {
-    if (binaryMap != null) {
-      binaryMap.getTotalMemoryConsumption
-    } else {
-      SizeEstimator.estimate(hashTable)
-    }
+    binaryMap.getTotalMemoryConsumption
   }
 
   override def get(key: InternalRow): Seq[InternalRow] = {
     val unsafeKey = key.asInstanceOf[UnsafeRow]
-
-    if (binaryMap != null) {
-      // Used in Broadcast join
-      val map = binaryMap  // avoid the compiler error
-      val loc = new map.Location  // this could be allocated in stack
-      binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
-        unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
-      if (loc.isDefined) {
-        val buffer = CompactBuffer[UnsafeRow]()
-
-        val base = loc.getValueBase
-        var offset = loc.getValueOffset
-        val last = offset + loc.getValueLength
-        while (offset < last) {
-          val numFields = Platform.getInt(base, offset)
-          val sizeInBytes = Platform.getInt(base, offset + 4)
-          offset += 8
-
-          val row = new UnsafeRow(numFields)
-          row.pointTo(base, offset, sizeInBytes)
-          buffer += row
-          offset += sizeInBytes
-        }
-        buffer
-      } else {
-        null
+    val map = binaryMap  // avoid the compiler error
+    val loc = new map.Location  // this could be allocated in stack
+    binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+      unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+    if (loc.isDefined) {
+      val buffer = CompactBuffer[UnsafeRow]()
+      val row = new UnsafeRow(numFields)
+      row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+      buffer += row
+      while (loc.nextValue()) {
+        val row = new UnsafeRow(numFields)
+        row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+        buffer += row
       }
-
+      buffer
     } else {
-      // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin)
-      hashTable.get(unsafeKey)
+      null
     }
   }
 
-  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
-    if (binaryMap != null) {
-      // This could happen when a cached broadcast object need to be dumped into disk to free memory
-      out.writeInt(binaryMap.numElements())
-
-      var buffer = new Array[Byte](64)
-      def write(base: Object, offset: Long, length: Int): Unit = {
-        if (buffer.length < length) {
-          buffer = new Array[Byte](length)
-        }
-        Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
-        out.write(buffer, 0, length)
-      }
+  override def close(): Unit = {
+    binaryMap.free()
+  }
 
-      val iter = binaryMap.iterator()
-      while (iter.hasNext) {
-        val loc = iter.next()
-        // [key size] [values size] [key bytes] [values bytes]
-        out.writeInt(loc.getKeyLength)
-        out.writeInt(loc.getValueLength)
-        write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
-        write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
+    out.writeInt(numFields)
+    // TODO: move these into BytesToBytesMap
+    out.writeInt(binaryMap.numKeys())
+    out.writeInt(binaryMap.numValues())
+
+    var buffer = new Array[Byte](64)
+    def write(base: Object, offset: Long, length: Int): Unit = {
+      if (buffer.length < length) {
+        buffer = new Array[Byte](length)
       }
+      Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
+      out.write(buffer, 0, length)
+    }
 
-    } else {
-      assert(hashTable != null)
-      out.writeInt(hashTable.size())
-
-      val iter = hashTable.entrySet().iterator()
-      while (iter.hasNext) {
-        val entry = iter.next()
-        val key = entry.getKey
-        val values = entry.getValue
-
-        // write all the values as single byte array
-        var totalSize = 0L
-        var i = 0
-        while (i < values.length) {
-          totalSize += values(i).getSizeInBytes + 4 + 4
-          i += 1
-        }
-        assert(totalSize < Integer.MAX_VALUE, "values are too big")
-
-        // [key size] [values size] [key bytes] [values bytes]
-        out.writeInt(key.getSizeInBytes)
-        out.writeInt(totalSize.toInt)
-        out.write(key.getBytes)
-        i = 0
-        while (i < values.length) {
-          // [num of fields] [num of bytes] [row bytes]
-          // write the integer in native order, so they can be read by UNSAFE.getInt()
-          if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
-            out.writeInt(values(i).numFields())
-            out.writeInt(values(i).getSizeInBytes)
-          } else {
-            out.writeInt(Integer.reverseBytes(values(i).numFields()))
-            out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
-          }
-          out.write(values(i).getBytes)
-          i += 1
-        }
-      }
+    val iter = binaryMap.iterator()
+    while (iter.hasNext) {
+      val loc = iter.next()
+      // [key size] [values size] [key bytes] [value bytes]
+      out.writeInt(loc.getKeyLength)
+      out.writeInt(loc.getValueLength)
+      write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
+      write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
     }
   }
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
+    numFields = in.readInt()
     val nKeys = in.readInt()
+    val nValues = in.readInt()
     // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
     // TODO(josh): This needs to be revisited before we merge this patch; making this change now
     // so that tests compile:
@@ -314,7 +239,7 @@ private[joins] final class UnsafeHashedRelation(
     var i = 0
     var keyBuffer = new Array[Byte](1024)
     var valuesBuffer = new Array[Byte](1024)
-    while (i < nKeys) {
+    while (i < nValues) {
       val keySize = in.readInt()
       val valuesSize = in.readInt()
       if (keySize > keyBuffer.length) {
@@ -326,13 +251,11 @@ private[joins] final class UnsafeHashedRelation(
       }
       in.readFully(valuesBuffer, 0, valuesSize)
 
-      // put it into binary map
       val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize)
-      assert(!loc.isDefined, "Duplicated key found!")
-      val putSuceeded = loc.putNewKey(
-        keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
+      val putSuceeded = loc.append(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
         valuesBuffer, Platform.BYTE_ARRAY_OFFSET, valuesSize)
       if (!putSuceeded) {
+        binaryMap.free()
         throw new IOException("Could not allocate memory to grow BytesToBytesMap")
       }
       i += 1
@@ -340,6 +263,29 @@ private[joins] final class UnsafeHashedRelation(
   }
 }
 
+/**
+ * A HashedRelation for UnsafeRow with unique keys.
+ */
+private[joins] final class UniqueUnsafeHashedRelation(
+    private var numFields: Int,
+    private var binaryMap: BytesToBytesMap)
+  extends UnsafeHashedRelation(numFields, binaryMap) with UniqueHashedRelation {
+  def getValue(key: InternalRow): InternalRow = {
+    val unsafeKey = key.asInstanceOf[UnsafeRow]
+    val map = binaryMap  // `new x.Location` needs a stable identifier, but binaryMap is a var
+    val loc = new map.Location  // this could be allocated on the stack
+    binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+      unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+    if (loc.isDefined) {
+      val row = new UnsafeRow(numFields)
+      row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+      row
+    } else {
+      null
+    }
+  }
+}
+
 private[joins] object UnsafeHashedRelation {
 
   def apply(
@@ -347,29 +293,54 @@ private[joins] object UnsafeHashedRelation {
       keyGenerator: UnsafeProjection,
       sizeEstimate: Int): HashedRelation = {
 
-    // Use a Java hash table here because unsafe maps expect fixed size records
-    // TODO: Use BytesToBytesMap for memory efficiency
-    val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
+    val taskMemoryManager = if (TaskContext.get() != null) {
+      TaskContext.get().taskMemoryManager()
+    } else {
+      new TaskMemoryManager(
+        new StaticMemoryManager(
+          new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+          Long.MaxValue,
+          Long.MaxValue,
+          1),
+        0)
+    }
+    val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
+      .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))
+
+    val binaryMap = new BytesToBytesMap(
+      taskMemoryManager,
+      // Only 70% of the slots can be used before growing; extra capacity helps reduce collisions
+      (sizeEstimate * 1.5 + 1).toInt,
+      pageSizeBytes)
 
     // Create a mapping of buildKeys -> rows
+    var numFields = 0
+    // Whether all the keys are unique or not
+    var allUnique: Boolean = true
     while (input.hasNext) {
-      val unsafeRow = input.next().asInstanceOf[UnsafeRow]
-      val rowKey = keyGenerator(unsafeRow)
-      if (!rowKey.anyNull) {
-        val existingMatchList = hashTable.get(rowKey)
-        val matchList = if (existingMatchList == null) {
-          val newMatchList = new CompactBuffer[UnsafeRow]()
-          hashTable.put(rowKey.copy(), newMatchList)
-          newMatchList
-        } else {
-          existingMatchList
+      val row = input.next().asInstanceOf[UnsafeRow]
+      numFields = row.numFields()
+      val key = keyGenerator(row)
+      if (!key.anyNull) {
+        val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
+        if (loc.isDefined) {
+          allUnique = false
+        }
+        val success = loc.append(
+          key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+          row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
+        if (!success) {
+          binaryMap.free()
+          throw new SparkException("There is no enough memory to build hash map")
         }
-        matchList += unsafeRow
       }
     }
 
-    // TODO: create UniqueUnsafeRelation
-    new UnsafeHashedRelation(hashTable)
+    if (allUnique) {
+      new UniqueUnsafeHashedRelation(numFields, binaryMap)
+    } else {
+      new UnsafeHashedRelation(numFields, binaryMap)
+    }
   }
 }
 
@@ -523,7 +494,7 @@ private[joins] object LongHashedRelation {
     keyGenerator: Projection,
     sizeEstimate: Int): HashedRelation = {
 
-    // Use a Java hash table here because unsafe maps expect fixed size records
+    // TODO: use LongToBytesMap for better memory efficiency
     val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate)
 
     // Create a mapping of key -> rows

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b58f14/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 5c4f1ef..e3a2eae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -57,9 +57,19 @@ case class ShuffledHashJoin(
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
   private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = {
+    val context = TaskContext.get()
+    if (!canJoinKeyFitWithinLong) {
+      // build BytesToBytesMap
+      val relation = HashedRelation(canJoinKeyFitWithinLong, iter, buildSideKeyGenerator)
+      // This relation is usually used until the end of task.
+      context.addTaskCompletionListener((t: TaskContext) =>
+        relation.close()
+      )
+      return relation
+    }
+
     // try to acquire some memory for the hash table, it could trigger other operator to free some
     // memory. The memory acquired here will mostly be used until the end of task.
-    val context = TaskContext.get()
     val memoryManager = context.taskMemoryManager()
     var acquired = 0L
     var used = 0L
@@ -69,18 +79,18 @@ case class ShuffledHashJoin(
 
     val copiedIter = iter.map { row =>
       // It's hard to guess what's exactly memory will be used, we have a rough guess here.
-      // TODO: use BytesToBytesMap instead of HashMap for memory efficiency
-      // Each pair in HashMap will have two UnsafeRows, one CompactBuffer, maybe 10+ pointers
+      // TODO: use LongToBytesMap instead of HashMap for memory efficiency
+      // Each pair in HashMap will have UnsafeRow, CompactBuffer, maybe 10+ pointers
       val needed = 150 + row.getSizeInBytes
       if (needed > acquired - used) {
         val got = memoryManager.acquireExecutionMemory(
           Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null)
+        acquired += got
         if (got < needed) {
           throw new SparkException("Can't acquire enough memory to build hash map in shuffled" +
             "hash join, please use sort merge join by setting " +
             "spark.sql.join.preferSortMergeJoin=true")
         }
-        acquired += got
       }
       used += needed
       // HashedRelation requires that the UnsafeRow should be separate objects.

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b58f14/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 0b1cb90..a16092e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -184,11 +184,29 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
 
     /**
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-    Join w 2 longs:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    Join w 2 longs:                     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    Join w 2 longs codegen=false              7877 / 8358         13.3          75.1       1.0X
-    Join w 2 longs codegen=true               3877 / 3937         27.0          37.0       2.0X
+    Join w 2 longs codegen=false           12725 / 13158          8.2         121.4       1.0X
+    Join w 2 longs codegen=true              6044 / 6771         17.3          57.6       2.1X
       */
+
+    val dim4 = broadcast(sqlContext.range(M)
+      .selectExpr("cast(id/10 as long) as k1", "cast(id/10 as long) as k2"))
+
+    runBenchmark("Join w 2 longs duplicated", N) {
+      sqlContext.range(N).join(dim4,
+        (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2"))
+        .count()
+    }
+
+    /**
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    Join w 2 longs duplicated:          Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    Join w 2 longs duplicated codegen=false 13066 / 13710          8.0         124.6       1.0X
+    Join w 2 longs duplicated codegen=true    7122 / 7277         14.7          67.9       1.8X
+     */
+
     runBenchmark("outer join w long", N) {
       sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "left").count()
     }
@@ -438,7 +456,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
             value.setInt(0, value.getInt(0) + 1)
             i += 1
           } else {
-            loc.putNewKey(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+            loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
               value.getBaseObject, value.getBaseOffset, value.getSizeInBytes)
           }
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b58f14/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index e19b4ff..ed4cc1c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.joins
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.util.collection.CompactBuffer
 
 class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
@@ -69,10 +71,17 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
   }
 
   test("test serialization empty hash map") {
+    val taskMemoryManager = new TaskMemoryManager(
+      new StaticMemoryManager(
+        new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+        Long.MaxValue,
+        Long.MaxValue,
+        1),
+      0)
+    val binaryMap = new BytesToBytesMap(taskMemoryManager, 1, 1)
     val os = new ByteArrayOutputStream()
     val out = new ObjectOutputStream(os)
-    val hashed = new UnsafeHashedRelation(
-      new java.util.HashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
+    val hashed = new UnsafeHashedRelation(1, binaryMap)
     hashed.writeExternal(out)
     out.flush()
     val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org