You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2016/02/10 01:44:39 UTC

spark git commit: [SPARK-12950] [SQL] Improve lookup of BytesToBytesMap in aggregate

Repository: spark
Updated Branches:
  refs/heads/master fae830d15 -> 0e5ebac3c


[SPARK-12950] [SQL] Improve lookup of BytesToBytesMap in aggregate

This PR improves the lookup of BytesToBytesMap by:

1. Generating code to calculate the hash code of the grouping keys.

2. No longer using MemoryLocation: the baseObject and offset for the key and value are fetched directly, removing a level of indirection.

Author: Davies Liu <da...@databricks.com>

Closes #11010 from davies/gen_map.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0e5ebac3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0e5ebac3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0e5ebac3

Branch: refs/heads/master
Commit: 0e5ebac3c1f1ff58f938be59c7c9e604977d269c
Parents: fae830d
Author: Davies Liu <da...@databricks.com>
Authored: Tue Feb 9 16:41:21 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Tue Feb 9 16:41:21 2016 -0800

----------------------------------------------------------------------
 .../spark/unsafe/map/BytesToBytesMap.java       | 108 +++++++++++--------
 .../map/AbstractBytesToBytesMapSuite.java       |  64 ++++++-----
 project/MimaExcludes.scala                      |   1 +
 .../spark/sql/catalyst/expressions/misc.scala   |   1 -
 .../UnsafeFixedWidthAggregationMap.java         |  34 +++---
 .../sql/execution/UnsafeKVExternalSorter.java   |   4 +-
 .../spark/sql/execution/WholeStageCodegen.scala |   6 +-
 .../execution/aggregate/TungstenAggregate.scala |  10 +-
 .../sql/execution/joins/HashedRelation.scala    |  17 ++-
 .../execution/BenchmarkWholeStageCodegen.scala  |  64 ++++++++---
 10 files changed, 182 insertions(+), 127 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0e5ebac3/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 3387f9a..b55a322 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -38,7 +38,6 @@ import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.unsafe.hash.Murmur3_x86_32;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.MemoryLocation;
 import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader;
 import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter;
 
@@ -65,8 +64,6 @@ public final class BytesToBytesMap extends MemoryConsumer {
 
   private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class);
 
-  private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0);
-
   private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
 
   private final TaskMemoryManager taskMemoryManager;
@@ -417,7 +414,19 @@ public final class BytesToBytesMap extends MemoryConsumer {
    * This function always return the same {@link Location} instance to avoid object allocation.
    */
   public Location lookup(Object keyBase, long keyOffset, int keyLength) {
-    safeLookup(keyBase, keyOffset, keyLength, loc);
+    safeLookup(keyBase, keyOffset, keyLength, loc,
+      Murmur3_x86_32.hashUnsafeWords(keyBase, keyOffset, keyLength, 42));
+    return loc;
+  }
+
+  /**
+   * Looks up a key, and return a {@link Location} handle that can be used to test existence
+   * and read/write values.
+   *
+   * This function always return the same {@link Location} instance to avoid object allocation.
+   */
+  public Location lookup(Object keyBase, long keyOffset, int keyLength, int hash) {
+    safeLookup(keyBase, keyOffset, keyLength, loc, hash);
     return loc;
   }
 
@@ -426,14 +435,13 @@ public final class BytesToBytesMap extends MemoryConsumer {
    *
    * This is a thread-safe version of `lookup`, could be used by multiple threads.
    */
-  public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc) {
+  public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc, int hash) {
     assert(longArray != null);
 
     if (enablePerfMetrics) {
       numKeyLookups++;
     }
-    final int hashcode = HASHER.hashUnsafeWords(keyBase, keyOffset, keyLength);
-    int pos = hashcode & mask;
+    int pos = hash & mask;
     int step = 1;
     while (true) {
       if (enablePerfMetrics) {
@@ -441,22 +449,19 @@ public final class BytesToBytesMap extends MemoryConsumer {
       }
       if (longArray.get(pos * 2) == 0) {
         // This is a new key.
-        loc.with(pos, hashcode, false);
+        loc.with(pos, hash, false);
         return;
       } else {
         long stored = longArray.get(pos * 2 + 1);
-        if ((int) (stored) == hashcode) {
+        if ((int) (stored) == hash) {
           // Full hash code matches.  Let's compare the keys for equality.
-          loc.with(pos, hashcode, true);
+          loc.with(pos, hash, true);
           if (loc.getKeyLength() == keyLength) {
-            final MemoryLocation keyAddress = loc.getKeyAddress();
-            final Object storedkeyBase = keyAddress.getBaseObject();
-            final long storedkeyOffset = keyAddress.getBaseOffset();
             final boolean areEqual = ByteArrayMethods.arrayEquals(
               keyBase,
               keyOffset,
-              storedkeyBase,
-              storedkeyOffset,
+              loc.getKeyBase(),
+              loc.getKeyOffset(),
               keyLength
             );
             if (areEqual) {
@@ -484,13 +489,14 @@ public final class BytesToBytesMap extends MemoryConsumer {
     private boolean isDefined;
     /**
      * The hashcode of the most recent key passed to
-     * {@link BytesToBytesMap#lookup(Object, long, int)}. Caching this hashcode here allows us to
-     * avoid re-hashing the key when storing a value for that key.
+     * {@link BytesToBytesMap#lookup(Object, long, int, int)}. Caching this hashcode here allows us
+     * to avoid re-hashing the key when storing a value for that key.
      */
     private int keyHashcode;
-    private final MemoryLocation keyMemoryLocation = new MemoryLocation();
-    private final MemoryLocation valueMemoryLocation = new MemoryLocation();
+    private Object baseObject;  // the base object for key and value
+    private long keyOffset;
     private int keyLength;
+    private long valueOffset;
     private int valueLength;
 
     /**
@@ -504,18 +510,15 @@ public final class BytesToBytesMap extends MemoryConsumer {
         taskMemoryManager.getOffsetInPage(fullKeyAddress));
     }
 
-    private void updateAddressesAndSizes(final Object base, final long offset) {
-      long position = offset;
-      final int totalLength = Platform.getInt(base, position);
-      position += 4;
-      keyLength = Platform.getInt(base, position);
-      position += 4;
+    private void updateAddressesAndSizes(final Object base, long offset) {
+      baseObject = base;
+      final int totalLength = Platform.getInt(base, offset);
+      offset += 4;
+      keyLength = Platform.getInt(base, offset);
+      offset += 4;
+      keyOffset = offset;
+      valueOffset = offset + keyLength;
       valueLength = totalLength - keyLength - 4;
-
-      keyMemoryLocation.setObjAndOffset(base, position);
-
-      position += keyLength;
-      valueMemoryLocation.setObjAndOffset(base, position);
     }
 
     private Location with(int pos, int keyHashcode, boolean isDefined) {
@@ -543,10 +546,11 @@ public final class BytesToBytesMap extends MemoryConsumer {
     private Location with(Object base, long offset, int length) {
       this.isDefined = true;
       this.memoryPage = null;
+      baseObject = base;
+      keyOffset = offset + 4;
       keyLength = Platform.getInt(base, offset);
+      valueOffset = offset + 4 + keyLength;
       valueLength = length - 4 - keyLength;
-      keyMemoryLocation.setObjAndOffset(base, offset + 4);
-      valueMemoryLocation.setObjAndOffset(base, offset + 4 + keyLength);
       return this;
     }
 
@@ -566,34 +570,44 @@ public final class BytesToBytesMap extends MemoryConsumer {
     }
 
     /**
-     * Returns the address of the key defined at this position.
-     * This points to the first byte of the key data.
-     * Unspecified behavior if the key is not defined.
-     * For efficiency reasons, calls to this method always returns the same MemoryLocation object.
+     * Returns the base object for key.
      */
-    public MemoryLocation getKeyAddress() {
+    public Object getKeyBase() {
       assert (isDefined);
-      return keyMemoryLocation;
+      return baseObject;
     }
 
     /**
-     * Returns the length of the key defined at this position.
-     * Unspecified behavior if the key is not defined.
+     * Returns the offset for key.
      */
-    public int getKeyLength() {
+    public long getKeyOffset() {
       assert (isDefined);
-      return keyLength;
+      return keyOffset;
+    }
+
+    /**
+     * Returns the base object for value.
+     */
+    public Object getValueBase() {
+      assert (isDefined);
+      return baseObject;
     }
 
     /**
-     * Returns the address of the value defined at this position.
-     * This points to the first byte of the value data.
+     * Returns the offset for value.
+     */
+    public long getValueOffset() {
+      assert (isDefined);
+      return valueOffset;
+    }
+
+    /**
+     * Returns the length of the key defined at this position.
      * Unspecified behavior if the key is not defined.
-     * For efficiency reasons, calls to this method always returns the same MemoryLocation object.
      */
-    public MemoryLocation getValueAddress() {
+    public int getKeyLength() {
       assert (isDefined);
-      return valueMemoryLocation;
+      return keyLength;
     }
 
     /**

http://git-wip-us.apache.org/repos/asf/spark/blob/0e5ebac3/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 702ba54..d8af2b3 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -39,14 +39,13 @@ import org.mockito.stubbing.Answer;
 
 import org.apache.spark.SparkConf;
 import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.storage.*;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
-import org.apache.spark.unsafe.memory.MemoryLocation;
 import org.apache.spark.util.Utils;
 
 import static org.hamcrest.Matchers.greaterThan;
@@ -142,10 +141,9 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   protected abstract boolean useOffHeapMemoryAllocator();
 
-  private static byte[] getByteArray(MemoryLocation loc, int size) {
+  private static byte[] getByteArray(Object base, long offset, int size) {
     final byte[] arr = new byte[size];
-    Platform.copyMemory(
-      loc.getBaseObject(), loc.getBaseOffset(), arr, Platform.BYTE_ARRAY_OFFSET, size);
+    Platform.copyMemory(base, offset, arr, Platform.BYTE_ARRAY_OFFSET, size);
     return arr;
   }
 
@@ -163,13 +161,14 @@ public abstract class AbstractBytesToBytesMapSuite {
    */
   private static boolean arrayEquals(
       byte[] expected,
-      MemoryLocation actualAddr,
+      Object base,
+      long offset,
       long actualLengthBytes) {
     return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals(
       expected,
       Platform.BYTE_ARRAY_OFFSET,
-      actualAddr.getBaseObject(),
-      actualAddr.getBaseOffset(),
+      base,
+      offset,
       expected.length
     );
   }
@@ -212,16 +211,20 @@ public abstract class AbstractBytesToBytesMapSuite {
       // reflect the result of this store without us having to call lookup() again on the same key.
       Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
       Assert.assertEquals(recordLengthBytes, loc.getValueLength());
-      Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
-      Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
+      Assert.assertArrayEquals(keyData,
+        getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes));
+      Assert.assertArrayEquals(valueData,
+        getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes));
 
       // After calling lookup() the location should still point to the correct data.
       Assert.assertTrue(
         map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined());
       Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
       Assert.assertEquals(recordLengthBytes, loc.getValueLength());
-      Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
-      Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
+      Assert.assertArrayEquals(keyData,
+        getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes));
+      Assert.assertArrayEquals(valueData,
+        getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes));
 
       try {
         Assert.assertTrue(loc.putNewKey(
@@ -283,15 +286,12 @@ public abstract class AbstractBytesToBytesMapSuite {
       while (iter.hasNext()) {
         final BytesToBytesMap.Location loc = iter.next();
         Assert.assertTrue(loc.isDefined());
-        final MemoryLocation keyAddress = loc.getKeyAddress();
-        final MemoryLocation valueAddress = loc.getValueAddress();
-        final long value = Platform.getLong(
-          valueAddress.getBaseObject(), valueAddress.getBaseOffset());
+        final long value = Platform.getLong(loc.getValueBase(), loc.getValueOffset());
         final long keyLength = loc.getKeyLength();
         if (keyLength == 0) {
           Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0);
         } else {
-          final long key = Platform.getLong(keyAddress.getBaseObject(), keyAddress.getBaseOffset());
+          final long key = Platform.getLong(loc.getKeyBase(), loc.getKeyOffset());
           Assert.assertEquals(value, key);
         }
         valuesSeen.set((int) value);
@@ -365,15 +365,15 @@ public abstract class AbstractBytesToBytesMapSuite {
         Assert.assertEquals(KEY_LENGTH, loc.getKeyLength());
         Assert.assertEquals(VALUE_LENGTH, loc.getValueLength());
         Platform.copyMemory(
-          loc.getKeyAddress().getBaseObject(),
-          loc.getKeyAddress().getBaseOffset(),
+          loc.getKeyBase(),
+          loc.getKeyOffset(),
           key,
           Platform.LONG_ARRAY_OFFSET,
           KEY_LENGTH
         );
         Platform.copyMemory(
-          loc.getValueAddress().getBaseObject(),
-          loc.getValueAddress().getBaseOffset(),
+          loc.getValueBase(),
+          loc.getValueOffset(),
           value,
           Platform.LONG_ARRAY_OFFSET,
           VALUE_LENGTH
@@ -425,8 +425,9 @@ public abstract class AbstractBytesToBytesMapSuite {
           Assert.assertTrue(loc.isDefined());
           Assert.assertEquals(key.length, loc.getKeyLength());
           Assert.assertEquals(value.length, loc.getValueLength());
-          Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
-          Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
+          Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length));
+          Assert.assertTrue(
+            arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length));
         }
       }
 
@@ -436,8 +437,10 @@ public abstract class AbstractBytesToBytesMapSuite {
         final BytesToBytesMap.Location loc =
           map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
         Assert.assertTrue(loc.isDefined());
-        Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
-        Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
+        Assert.assertTrue(
+          arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength()));
+        Assert.assertTrue(
+          arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength()));
       }
     } finally {
       map.free();
@@ -476,8 +479,9 @@ public abstract class AbstractBytesToBytesMapSuite {
           Assert.assertTrue(loc.isDefined());
           Assert.assertEquals(key.length, loc.getKeyLength());
           Assert.assertEquals(value.length, loc.getValueLength());
-          Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
-          Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
+          Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length));
+          Assert.assertTrue(
+            arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length));
         }
       }
       for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
@@ -486,8 +490,10 @@ public abstract class AbstractBytesToBytesMapSuite {
         final BytesToBytesMap.Location loc =
           map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
         Assert.assertTrue(loc.isDefined());
-        Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
-        Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
+        Assert.assertTrue(
+          arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength()));
+        Assert.assertTrue(
+          arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength()));
       }
     } finally {
       map.free();

http://git-wip-us.apache.org/repos/asf/spark/blob/0e5ebac3/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 9209094..1338947 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -40,6 +40,7 @@ object MimaExcludes {
         excludePackage("org.apache.spark.rpc"),
         excludePackage("org.spark-project.jetty"),
         excludePackage("org.apache.spark.unused"),
+        excludePackage("org.apache.spark.unsafe"),
         excludePackage("org.apache.spark.util.collection.unsafe"),
         excludePackage("org.apache.spark.sql.catalyst"),
         excludePackage("org.apache.spark.sql.execution"),

http://git-wip-us.apache.org/repos/asf/spark/blob/0e5ebac3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index f4ccadd..28e4f50 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -322,7 +322,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
     }
   }
 
-
   override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
     ev.isNull = "false"
     val childrenHash = children.map { child =>

http://git-wip-us.apache.org/repos/asf/spark/blob/0e5ebac3/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 6bf9d7b..2e84178 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -121,19 +121,24 @@ public final class UnsafeFixedWidthAggregationMap {
     return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow);
   }
 
-  public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRow) {
+  public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key) {
+    return getAggregationBufferFromUnsafeRow(key, key.hashCode());
+  }
+
+  public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key, int hash) {
     // Probe our map using the serialized key
     final BytesToBytesMap.Location loc = map.lookup(
-      unsafeGroupingKeyRow.getBaseObject(),
-      unsafeGroupingKeyRow.getBaseOffset(),
-      unsafeGroupingKeyRow.getSizeInBytes());
+      key.getBaseObject(),
+      key.getBaseOffset(),
+      key.getSizeInBytes(),
+      hash);
     if (!loc.isDefined()) {
       // This is the first time that we've seen this grouping key, so we'll insert a copy of the
       // empty aggregation buffer into the map:
       boolean putSucceeded = loc.putNewKey(
-        unsafeGroupingKeyRow.getBaseObject(),
-        unsafeGroupingKeyRow.getBaseOffset(),
-        unsafeGroupingKeyRow.getSizeInBytes(),
+        key.getBaseObject(),
+        key.getBaseOffset(),
+        key.getSizeInBytes(),
         emptyAggregationBuffer,
         Platform.BYTE_ARRAY_OFFSET,
         emptyAggregationBuffer.length
@@ -144,10 +149,9 @@ public final class UnsafeFixedWidthAggregationMap {
     }
 
     // Reset the pointer to point to the value that we just stored or looked up:
-    final MemoryLocation address = loc.getValueAddress();
     currentAggregationBuffer.pointTo(
-      address.getBaseObject(),
-      address.getBaseOffset(),
+      loc.getValueBase(),
+      loc.getValueOffset(),
       loc.getValueLength()
     );
     return currentAggregationBuffer;
@@ -172,16 +176,14 @@ public final class UnsafeFixedWidthAggregationMap {
       public boolean next() {
         if (mapLocationIterator.hasNext()) {
           final BytesToBytesMap.Location loc = mapLocationIterator.next();
-          final MemoryLocation keyAddress = loc.getKeyAddress();
-          final MemoryLocation valueAddress = loc.getValueAddress();
           key.pointTo(
-            keyAddress.getBaseObject(),
-            keyAddress.getBaseOffset(),
+            loc.getKeyBase(),
+            loc.getKeyOffset(),
             loc.getKeyLength()
           );
           value.pointTo(
-            valueAddress.getBaseObject(),
-            valueAddress.getBaseOffset(),
+            loc.getValueBase(),
+            loc.getValueOffset(),
             loc.getValueLength()
           );
           return true;

http://git-wip-us.apache.org/repos/asf/spark/blob/0e5ebac3/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 0da26bf..51e10b0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -97,8 +97,8 @@ public final class UnsafeKVExternalSorter {
       UnsafeRow row = new UnsafeRow(numKeyFields);
       while (iter.hasNext()) {
         final BytesToBytesMap.Location loc = iter.next();
-        final Object baseObject = loc.getKeyAddress().getBaseObject();
-        final long baseOffset = loc.getKeyAddress().getBaseOffset();
+        final Object baseObject = loc.getKeyBase();
+        final long baseOffset = loc.getKeyOffset();
 
         // Get encoded memory address
         // baseObject + baseOffset point to the beginning of the key data in the map, but that

http://git-wip-us.apache.org/repos/asf/spark/blob/0e5ebac3/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 4ca2d85..b200239 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -366,11 +366,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
   def apply(plan: SparkPlan): SparkPlan = {
     if (sqlContext.conf.wholeStageEnabled) {
       plan.transform {
-        case plan: CodegenSupport if supportCodegen(plan) &&
-          // Whole stage codegen is only useful when there are at least two levels of operators that
-          // support it (save at least one projection/iterator).
-          (Utils.isTesting || plan.children.exists(supportCodegen)) =>
-
+        case plan: CodegenSupport if supportCodegen(plan) =>
           var inputs = ArrayBuffer[SparkPlan]()
           val combined = plan.transform {
             // The build side can't be compiled together

http://git-wip-us.apache.org/repos/asf/spark/blob/0e5ebac3/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 9d9f14f..340b8f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -501,6 +501,11 @@ case class TungstenAggregate(
       }
     }
 
+    // generate hash code for key
+    val hashExpr = Murmur3Hash(groupingExpressions, 42)
+    ctx.currentVars = input
+    val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx)
+
     val inputAttr = bufferAttributes ++ child.output
     ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
     ctx.INPUT_ROW = buffer
@@ -526,10 +531,11 @@ case class TungstenAggregate(
     s"""
      // generate grouping key
      ${keyCode.code.trim}
+     ${hashEval.code.trim}
      UnsafeRow $buffer = null;
      if ($checkFallback) {
        // try to get the buffer from hash map
-       $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+       $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value});
      }
      if ($buffer == null) {
        if ($sorterTerm == null) {
@@ -540,7 +546,7 @@ case class TungstenAggregate(
        $resetCoulter
        // the hash map had be spilled, it should have enough memory now,
        // try  to allocate buffer again.
-       $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+       $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value});
        if ($buffer == null) {
          // failed to allocate the first page
          throw new OutOfMemoryError("No enough memory for aggregation");

http://git-wip-us.apache.org/repos/asf/spark/blob/0e5ebac3/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index c94d6c1..eb6930a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -277,13 +277,13 @@ private[joins] final class UnsafeHashedRelation(
       val map = binaryMap  // avoid the compiler error
       val loc = new map.Location  // this could be allocated in stack
       binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
-        unsafeKey.getSizeInBytes, loc)
+        unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
       if (loc.isDefined) {
         val buffer = CompactBuffer[UnsafeRow]()
 
-        val base = loc.getValueAddress.getBaseObject
-        var offset = loc.getValueAddress.getBaseOffset
-        val last = loc.getValueAddress.getBaseOffset + loc.getValueLength
+        val base = loc.getValueBase
+        var offset = loc.getValueOffset
+        val last = offset + loc.getValueLength
         while (offset < last) {
           val numFields = Platform.getInt(base, offset)
           val sizeInBytes = Platform.getInt(base, offset + 4)
@@ -311,12 +311,11 @@ private[joins] final class UnsafeHashedRelation(
       out.writeInt(binaryMap.numElements())
 
       var buffer = new Array[Byte](64)
-      def write(addr: MemoryLocation, length: Int): Unit = {
+      def write(base: Object, offset: Long, length: Int): Unit = {
         if (buffer.length < length) {
           buffer = new Array[Byte](length)
         }
-        Platform.copyMemory(addr.getBaseObject, addr.getBaseOffset,
-          buffer, Platform.BYTE_ARRAY_OFFSET, length)
+        Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
         out.write(buffer, 0, length)
       }
 
@@ -326,8 +325,8 @@ private[joins] final class UnsafeHashedRelation(
         // [key size] [values size] [key bytes] [values bytes]
         out.writeInt(loc.getKeyLength)
         out.writeInt(loc.getValueLength)
-        write(loc.getKeyAddress, loc.getKeyLength)
-        write(loc.getValueAddress, loc.getValueLength)
+        write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
+        write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
       }
 
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/0e5ebac3/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index f015d29..dc6c647 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -114,11 +114,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     }
 
     /*
-    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-    Aggregate w keys:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
-    -------------------------------------------------------------------------------------------
-    Aggregate w keys codegen=false           2402 / 2551          8.0         125.0       1.0X
-    Aggregate w keys codegen=true            1620 / 1670         12.0          83.3       1.5X
+      Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+      Aggregate w keys:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+      -------------------------------------------------------------------------------------------
+      Aggregate w keys codegen=false           2429 / 2644          8.6         115.8       1.0X
+      Aggregate w keys codegen=true            1535 / 1571         13.7          73.2       1.6X
     */
   }
 
@@ -165,21 +165,51 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     benchmark.addCase("hash") { iter =>
       var i = 0
       val keyBytes = new Array[Byte](16)
-      val valueBytes = new Array[Byte](16)
       val key = new UnsafeRow(1)
       key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
-      val value = new UnsafeRow(2)
-      value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
       var s = 0
       while (i < N) {
         key.setInt(0, i % 1000)
         val h = Murmur3_x86_32.hashUnsafeWords(
-          key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0)
+          key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 42)
+        s += h
+        i += 1
+      }
+    }
+
+    benchmark.addCase("fast hash") { iter =>
+      var i = 0
+      val keyBytes = new Array[Byte](16)
+      val key = new UnsafeRow(1)
+      key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      var s = 0
+      while (i < N) {
+        key.setInt(0, i % 1000)
+        val h = Murmur3_x86_32.hashLong(i % 1000, 42)
         s += h
         i += 1
       }
     }
 
+    benchmark.addCase("arrayEqual") { iter =>
+      var i = 0
+      val keyBytes = new Array[Byte](16)
+      val valueBytes = new Array[Byte](16)
+      val key = new UnsafeRow(1)
+      key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      val value = new UnsafeRow(1)
+      value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      value.setInt(0, 555)
+      var s = 0
+      while (i < N) {
+        key.setInt(0, i % 1000)
+        if (key.equals(value)) {
+          s += 1
+        }
+        i += 1
+      }
+    }
+
     Seq("off", "on").foreach { heap =>
       benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
         val taskMemoryManager = new TaskMemoryManager(
@@ -195,15 +225,15 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
         val valueBytes = new Array[Byte](16)
         val key = new UnsafeRow(1)
         key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
-        val value = new UnsafeRow(2)
+        val value = new UnsafeRow(1)
         value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
         var i = 0
         while (i < N) {
           key.setInt(0, i % 65536)
-          val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
+          val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+            Murmur3_x86_32.hashLong(i % 65536, 42))
           if (loc.isDefined) {
-            value.pointTo(loc.getValueAddress.getBaseObject, loc.getValueAddress.getBaseOffset,
-              loc.getValueLength)
+            value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
             value.setInt(0, value.getInt(0) + 1)
             i += 1
           } else {
@@ -218,9 +248,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     BytesToBytesMap:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    hash                                      628 /  661         83.0          12.0       1.0X
-    BytesToBytesMap (off Heap)               3292 / 3408         15.0          66.7       0.2X
-    BytesToBytesMap (on Heap)                3349 / 4267         15.0          66.7       0.2X
+    hash                                      651 /  678         80.0          12.5       1.0X
+    fast hash                                 336 /  343        155.9           6.4       1.9X
+    arrayEqual                                417 /  428        125.0           8.0       1.6X
+    BytesToBytesMap (off Heap)               2594 / 2664         20.2          49.5       0.2X
+    BytesToBytesMap (on Heap)                2693 / 2989         19.5          51.4       0.2X
       */
     benchmark.run()
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org