You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/06/30 00:59:28 UTC

spark git commit: [SPARK-8579] [SQL] support arbitrary object in UnsafeRow

Repository: spark
Updated Branches:
  refs/heads/master 931da5c8a -> ed359de59


[SPARK-8579] [SQL] support arbitrary object in UnsafeRow

This PR adds support for arbitrary objects in UnsafeRow (in both grouping keys and aggregation buffers).

Two object pools are created to hold these non-primitive objects, and each object's index in its pool is stored in the UnsafeRow. To allow grouping keys to be compared as raw bytes, objects in keys are stored in a unique object pool, which guarantees that equal objects receive the same index (the index is then used for hashing and equality).

For better performance, StringType and BinaryType values are still stored as variable-length data in the UnsafeRow when it is initialized. On update, a new value that fits in the existing space overwrites it in place; otherwise the value is stored as an object in the object pool, leaving some garbage behind in the row buffer.

Note: a JIRA ticket will be created once issues.apache.org is available.

cc JoshRosen rxin

Author: Davies Liu <da...@databricks.com>

Closes #6959 from davies/unsafe_obj and squashes the following commits:

5ce39da [Davies Liu] fix comment
5e797bf [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj
5803d64 [Davies Liu] fix conflict
461d304 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj
2f41c90 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj
b04d69c [Davies Liu] address comments
4859b80 [Davies Liu] fix comments
f38011c [Davies Liu] add a test for grouping by decimal
d2cf7ab [Davies Liu] add more tests for null checking
71983c5 [Davies Liu] add test for timestamp
e8a1649 [Davies Liu] reuse buffer for string
39f09ca [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj
035501e [Davies Liu] fix style
236d6de [Davies Liu] support arbitrary object in UnsafeRow


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ed359de5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ed359de5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ed359de5

Branch: refs/heads/master
Commit: ed359de595d5dd67b666660eddf092eaf89041c8
Parents: 931da5c
Author: Davies Liu <da...@databricks.com>
Authored: Mon Jun 29 15:59:20 2015 -0700
Committer: Davies Liu <da...@databricks.com>
Committed: Mon Jun 29 15:59:20 2015 -0700

----------------------------------------------------------------------
 .../UnsafeFixedWidthAggregationMap.java         | 144 ++++++------
 .../sql/catalyst/expressions/UnsafeRow.java     | 218 ++++++++++---------
 .../spark/sql/catalyst/util/ObjectPool.java     |  78 +++++++
 .../sql/catalyst/util/UniqueObjectPool.java     |  59 +++++
 .../apache/spark/sql/catalyst/InternalRow.scala |   5 +-
 .../expressions/UnsafeRowConverter.scala        |  94 ++++----
 .../UnsafeFixedWidthAggregationMapSuite.scala   |  65 ++++--
 .../expressions/UnsafeRowConverterSuite.scala   | 190 ++++++++++++----
 .../sql/catalyst/util/ObjectPoolSuite.scala     |  57 +++++
 .../sql/execution/GeneratedAggregate.scala      |  16 +-
 10 files changed, 615 insertions(+), 311 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ed359de5/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index 83f2a31..1e79f4b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions;
 
 import java.util.Iterator;
 
+import scala.Function1;
+
 import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.catalyst.util.ObjectPool;
+import org.apache.spark.sql.catalyst.util.UniqueObjectPool;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.map.BytesToBytesMap;
 import org.apache.spark.unsafe.memory.MemoryLocation;
@@ -38,16 +40,28 @@ public final class UnsafeFixedWidthAggregationMap {
    * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
    * map, we copy this buffer and use it as the value.
    */
-  private final byte[] emptyAggregationBuffer;
+  private final byte[] emptyBuffer;
 
-  private final StructType aggregationBufferSchema;
+  /**
+   * An empty row used by `initProjection`
+   */
+  private static final InternalRow emptyRow = new GenericInternalRow();
 
-  private final StructType groupingKeySchema;
+  /**
+   * Whether can the empty aggregation buffer be reuse without calling `initProjection` or not.
+   */
+  private final boolean reuseEmptyBuffer;
 
   /**
-   * Encodes grouping keys as UnsafeRows.
+   * The projection used to initialize the emptyBuffer
    */
-  private final UnsafeRowConverter groupingKeyToUnsafeRowConverter;
+  private final Function1<InternalRow, InternalRow> initProjection;
+
+  /**
+   * Encodes grouping keys or buffers as UnsafeRows.
+   */
+  private final UnsafeRowConverter keyConverter;
+  private final UnsafeRowConverter bufferConverter;
 
   /**
    * A hashmap which maps from opaque bytearray keys to bytearray values.
@@ -55,9 +69,19 @@ public final class UnsafeFixedWidthAggregationMap {
   private final BytesToBytesMap map;
 
   /**
+   * An object pool for objects that are used in grouping keys.
+   */
+  private final UniqueObjectPool keyPool;
+
+  /**
+   * An object pool for objects that are used in aggregation buffers.
+   */
+  private final ObjectPool bufferPool;
+
+  /**
    * Re-used pointer to the current aggregation buffer
    */
-  private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
+  private final UnsafeRow currentBuffer = new UnsafeRow();
 
   /**
    * Scratch space that is used when encoding grouping keys into UnsafeRow format.
@@ -70,67 +94,38 @@ public final class UnsafeFixedWidthAggregationMap {
   private final boolean enablePerfMetrics;
 
   /**
-   * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema,
-   *         false otherwise.
-   */
-  public static boolean supportsGroupKeySchema(StructType schema) {
-    for (StructField field: schema.fields()) {
-      if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  /**
-   * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
-   *         schema, false otherwise.
-   */
-  public static boolean supportsAggregationBufferSchema(StructType schema) {
-    for (StructField field: schema.fields()) {
-      if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  /**
    * Create a new UnsafeFixedWidthAggregationMap.
    *
-   * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
-   * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
-   * @param groupingKeySchema the schema of the grouping key, used for row conversion.
+   * @param initProjection the default value for new keys (a "zero" of the agg. function)
+   * @param keyConverter the converter of the grouping key, used for row conversion.
+   * @param bufferConverter the converter of the aggregation buffer, used for row conversion.
    * @param memoryManager the memory manager used to allocate our Unsafe memory structures.
    * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
    * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
    */
   public UnsafeFixedWidthAggregationMap(
-      InternalRow emptyAggregationBuffer,
-      StructType aggregationBufferSchema,
-      StructType groupingKeySchema,
+      Function1<InternalRow, InternalRow> initProjection,
+      UnsafeRowConverter keyConverter,
+      UnsafeRowConverter bufferConverter,
       TaskMemoryManager memoryManager,
       int initialCapacity,
       boolean enablePerfMetrics) {
-    this.emptyAggregationBuffer =
-      convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema);
-    this.aggregationBufferSchema = aggregationBufferSchema;
-    this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema);
-    this.groupingKeySchema = groupingKeySchema;
-    this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
+    this.initProjection = initProjection;
+    this.keyConverter = keyConverter;
+    this.bufferConverter = bufferConverter;
     this.enablePerfMetrics = enablePerfMetrics;
-  }
 
-  /**
-   * Convert a Java object row into an UnsafeRow, allocating it into a new byte array.
-   */
-  private static byte[] convertToUnsafeRow(InternalRow javaRow, StructType schema) {
-    final UnsafeRowConverter converter = new UnsafeRowConverter(schema);
-    final byte[] unsafeRow = new byte[converter.getSizeRequirement(javaRow)];
-    final int writtenLength =
-      converter.writeRow(javaRow, unsafeRow, PlatformDependent.BYTE_ARRAY_OFFSET);
-    assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!";
-    return unsafeRow;
+    this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
+    this.keyPool = new UniqueObjectPool(100);
+    this.bufferPool = new ObjectPool(initialCapacity);
+
+    InternalRow initRow = initProjection.apply(emptyRow);
+    this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)];
+    int writtenLength = bufferConverter.writeRow(
+      initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool);
+    assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!";
+    // re-use the empty buffer only when there is no object saved in pool.
+    reuseEmptyBuffer = bufferPool.size() == 0;
   }
 
   /**
@@ -138,15 +133,16 @@ public final class UnsafeFixedWidthAggregationMap {
    * return the same object.
    */
   public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
-    final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey);
+    final int groupingKeySize = keyConverter.getSizeRequirement(groupingKey);
     // Make sure that the buffer is large enough to hold the key. If it's not, grow it:
     if (groupingKeySize > groupingKeyConversionScratchSpace.length) {
       groupingKeyConversionScratchSpace = new byte[groupingKeySize];
     }
-    final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow(
+    final int actualGroupingKeySize = keyConverter.writeRow(
       groupingKey,
       groupingKeyConversionScratchSpace,
-      PlatformDependent.BYTE_ARRAY_OFFSET);
+      PlatformDependent.BYTE_ARRAY_OFFSET,
+      keyPool);
     assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";
 
     // Probe our map using the serialized key
@@ -157,25 +153,31 @@ public final class UnsafeFixedWidthAggregationMap {
     if (!loc.isDefined()) {
       // This is the first time that we've seen this grouping key, so we'll insert a copy of the
       // empty aggregation buffer into the map:
+      if (!reuseEmptyBuffer) {
+        // There is some objects referenced by emptyBuffer, so generate a new one
+        InternalRow initRow = initProjection.apply(emptyRow);
+        bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET,
+          bufferPool);
+      }
       loc.putNewKey(
         groupingKeyConversionScratchSpace,
         PlatformDependent.BYTE_ARRAY_OFFSET,
         groupingKeySize,
-        emptyAggregationBuffer,
+        emptyBuffer,
         PlatformDependent.BYTE_ARRAY_OFFSET,
-        emptyAggregationBuffer.length
+        emptyBuffer.length
       );
     }
 
     // Reset the pointer to point to the value that we just stored or looked up:
     final MemoryLocation address = loc.getValueAddress();
-    currentAggregationBuffer.pointTo(
+    currentBuffer.pointTo(
       address.getBaseObject(),
       address.getBaseOffset(),
-      aggregationBufferSchema.length(),
-      aggregationBufferSchema
+      bufferConverter.numFields(),
+      bufferPool
     );
-    return currentAggregationBuffer;
+    return currentBuffer;
   }
 
   /**
@@ -211,14 +213,14 @@ public final class UnsafeFixedWidthAggregationMap {
         entry.key.pointTo(
           keyAddress.getBaseObject(),
           keyAddress.getBaseOffset(),
-          groupingKeySchema.length(),
-          groupingKeySchema
+          keyConverter.numFields(),
+          keyPool
         );
         entry.value.pointTo(
           valueAddress.getBaseObject(),
           valueAddress.getBaseOffset(),
-          aggregationBufferSchema.length(),
-          aggregationBufferSchema
+          bufferConverter.numFields(),
+          bufferPool
         );
         return entry;
       }
@@ -246,6 +248,8 @@ public final class UnsafeFixedWidthAggregationMap {
     System.out.println("Number of hash collisions: " + map.getNumHashCollisions());
     System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs());
     System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
+    System.out.println("Number of unique objects in keys: " + keyPool.size());
+    System.out.println("Number of objects in buffers: " + bufferPool.size());
   }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed359de5/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 11d51d9..f077064 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -17,20 +17,12 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
-import javax.annotation.Nullable;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.Set;
-
 import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.catalyst.util.ObjectPool;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.bitset.BitSetMethods;
 import org.apache.spark.unsafe.types.UTF8String;
 
-import static org.apache.spark.sql.types.DataTypes.*;
 
 /**
  * An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
@@ -44,7 +36,20 @@ import static org.apache.spark.sql.types.DataTypes.*;
  * primitive types, such as long, double, or int, we store the value directly in the word. For
  * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the
  * base address of the row) that points to the beginning of the variable-length field, and length
- * (they are combined into a long).
+ * (they are combined into a long). For other objects, they are stored in a pool, the indexes of
+ * them are hold in the the word.
+ *
+ * In order to support fast hashing and equality checks for UnsafeRows that contain objects
+ * when used as grouping key in BytesToBytesMap, we put the objects in an UniqueObjectPool to make
+ * sure all the key have the same index for same object, then we can hash/compare the objects by
+ * hash/compare the index.
+ *
+ * For non-primitive types, the word of a field could be:
+ *   UNION {
+ *     [1] [offset: 31bits] [length: 31bits]  // StringType
+ *     [0] [offset: 31bits] [length: 31bits]  // BinaryType
+ *     - [index: 63bits]                      // StringType, Binary, index to object in pool
+ *   }
  *
  * Instances of `UnsafeRow` act as pointers to row data stored in this format.
  */
@@ -53,8 +58,12 @@ public final class UnsafeRow extends MutableRow {
   private Object baseObject;
   private long baseOffset;
 
+  /** A pool to hold non-primitive objects */
+  private ObjectPool pool;
+
   Object getBaseObject() { return baseObject; }
   long getBaseOffset() { return baseOffset; }
+  ObjectPool getPool() { return pool; }
 
   /** The number of fields in this row, used for calculating the bitset width (and in assertions) */
   private int numFields;
@@ -63,15 +72,6 @@ public final class UnsafeRow extends MutableRow {
 
   /** The width of the null tracking bit set, in bytes */
   private int bitSetWidthInBytes;
-  /**
-   * This optional schema is required if you want to call generic get() and set() methods on
-   * this UnsafeRow, but is optional if callers will only use type-specific getTYPE() and setTYPE()
-   * methods. This should be removed after the planned InternalRow / Row split; right now, it's only
-   * needed by the generic get() method, which is only called internally by code that accesses
-   * UTF8String-typed columns.
-   */
-  @Nullable
-  private StructType schema;
 
   private long getFieldOffset(int ordinal) {
    return baseOffset + bitSetWidthInBytes + ordinal * 8L;
@@ -81,42 +81,7 @@ public final class UnsafeRow extends MutableRow {
     return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8;
   }
 
-  /**
-   * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types)
-   */
-  public static final Set<DataType> settableFieldTypes;
-
-  /**
-   * Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException).
-   */
-  public static final Set<DataType> readableFieldTypes;
-
-  // TODO: support DecimalType
-  static {
-    settableFieldTypes = Collections.unmodifiableSet(
-      new HashSet<DataType>(
-        Arrays.asList(new DataType[] {
-          NullType,
-          BooleanType,
-          ByteType,
-          ShortType,
-          IntegerType,
-          LongType,
-          FloatType,
-          DoubleType,
-          DateType,
-          TimestampType
-    })));
-
-    // We support get() on a superset of the types for which we support set():
-    final Set<DataType> _readableFieldTypes = new HashSet<DataType>(
-      Arrays.asList(new DataType[]{
-        StringType,
-        BinaryType
-      }));
-    _readableFieldTypes.addAll(settableFieldTypes);
-    readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
-  }
+  public static final long OFFSET_BITS = 31L;
 
   /**
    * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called,
@@ -130,22 +95,15 @@ public final class UnsafeRow extends MutableRow {
    * @param baseObject the base object
    * @param baseOffset the offset within the base object
    * @param numFields the number of fields in this row
-   * @param schema an optional schema; this is necessary if you want to call generic get() or set()
-   *               methods on this row, but is optional if the caller will only use type-specific
-   *               getTYPE() and setTYPE() methods.
+   * @param pool the object pool to hold arbitrary objects
    */
-  public void pointTo(
-      Object baseObject,
-      long baseOffset,
-      int numFields,
-      @Nullable StructType schema) {
+  public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) {
     assert numFields >= 0 : "numFields should >= 0";
-    assert schema == null || schema.fields().length == numFields;
     this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
     this.baseObject = baseObject;
     this.baseOffset = baseOffset;
     this.numFields = numFields;
-    this.schema = schema;
+    this.pool = pool;
   }
 
   private void assertIndexIsValid(int index) {
@@ -168,9 +126,68 @@ public final class UnsafeRow extends MutableRow {
     BitSetMethods.unset(baseObject, baseOffset, i);
   }
 
+  /**
+   * Updates the column `i` as Object `value`, which cannot be primitive types.
+   */
   @Override
-  public void update(int ordinal, Object value) {
-    throw new UnsupportedOperationException();
+  public void update(int i, Object value) {
+    if (value == null) {
+      if (!isNullAt(i)) {
+        // remove the old value from pool
+        long idx = getLong(i);
+        if (idx <= 0) {
+          // this is the index of old value in pool, remove it
+          pool.replace((int)-idx, null);
+        } else {
+          // there will be some garbage left (UTF8String or byte[])
+        }
+        setNullAt(i);
+      }
+      return;
+    }
+
+    if (isNullAt(i)) {
+      // there is not an old value, put the new value into pool
+      int idx = pool.put(value);
+      setLong(i, (long)-idx);
+    } else {
+      // there is an old value, check the type, then replace it or update it
+      long v = getLong(i);
+      if (v <= 0) {
+        // it's the index in the pool, replace old value with new one
+        int idx = (int)-v;
+        pool.replace(idx, value);
+      } else {
+        // old value is UTF8String or byte[], try to reuse the space
+        boolean isString;
+        byte[] newBytes;
+        if (value instanceof UTF8String) {
+          newBytes = ((UTF8String) value).getBytes();
+          isString = true;
+        } else {
+          newBytes = (byte[]) value;
+          isString = false;
+        }
+        int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE);
+        int oldLength = (int) (v & Integer.MAX_VALUE);
+        if (newBytes.length <= oldLength) {
+          // the new value can fit in the old buffer, re-use it
+          PlatformDependent.copyMemory(
+            newBytes,
+            PlatformDependent.BYTE_ARRAY_OFFSET,
+            baseObject,
+            baseOffset + offset,
+            newBytes.length);
+          long flag = isString ? 1L << (OFFSET_BITS * 2) : 0L;
+          setLong(i, flag | (((long) offset) << OFFSET_BITS) | (long) newBytes.length);
+        } else {
+          // Cannot fit in the buffer
+          int idx = pool.put(value);
+          setLong(i, (long) -idx);
+        }
+      }
+    }
+    setNotNullAt(i);
   }
 
   @Override
@@ -227,28 +244,38 @@ public final class UnsafeRow extends MutableRow {
     return numFields;
   }
 
-  @Override
-  public StructType schema() {
-    return schema;
-  }
-
+  /**
+   * Returns the object for column `i`, which should not be primitive type.
+   */
   @Override
   public Object get(int i) {
     assertIndexIsValid(i);
-    assert (schema != null) : "Schema must be defined when calling generic get() method";
-    final DataType dataType = schema.fields()[i].dataType();
-    // UnsafeRow is only designed to be invoked by internal code, which only invokes this generic
-    // get() method when trying to access UTF8String-typed columns. If we refactor the codebase to
-    // separate the internal and external row interfaces, then internal code can fetch strings via
-    // a new getUTF8String() method and we'll be able to remove this method.
     if (isNullAt(i)) {
       return null;
-    } else if (dataType == StringType) {
-      return getUTF8String(i);
-    } else if (dataType == BinaryType) {
-      return getBinary(i);
+    }
+    long v = PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i));
+    if (v <= 0) {
+      // It's an index to object in the pool.
+      int idx = (int)-v;
+      return pool.get(idx);
     } else {
-      throw new UnsupportedOperationException();
+      // The column could be StingType or BinaryType
+      boolean isString = (v >> (OFFSET_BITS * 2)) > 0;
+      int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE);
+      int size = (int) (v & Integer.MAX_VALUE);
+      final byte[] bytes = new byte[size];
+      PlatformDependent.copyMemory(
+        baseObject,
+        baseOffset + offset,
+        bytes,
+        PlatformDependent.BYTE_ARRAY_OFFSET,
+        size
+      );
+      if (isString) {
+        return UTF8String.fromBytes(bytes);
+      } else {
+        return bytes;
+      }
     }
   }
 
@@ -308,31 +335,6 @@ public final class UnsafeRow extends MutableRow {
     }
   }
 
-  public UTF8String getUTF8String(int i) {
-    return UTF8String.fromBytes(getBinary(i));
-  }
-
-  public byte[] getBinary(int i) {
-    assertIndexIsValid(i);
-    final long offsetAndSize = getLong(i);
-    final int offset = (int)(offsetAndSize >> 32);
-    final int size = (int)(offsetAndSize & ((1L << 32) - 1));
-    final byte[] bytes = new byte[size];
-    PlatformDependent.copyMemory(
-      baseObject,
-      baseOffset + offset,
-      bytes,
-      PlatformDependent.BYTE_ARRAY_OFFSET,
-      size
-    );
-    return bytes;
-  }
-
-  @Override
-  public String getString(int i) {
-    return getUTF8String(i).toString();
-  }
-
   @Override
   public InternalRow copy() {
     throw new UnsupportedOperationException();

http://git-wip-us.apache.org/repos/asf/spark/blob/ed359de5/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java
new file mode 100644
index 0000000..97f89a7
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util;
+
+/**
+ * A object pool stores a collection of objects in array, then they can be referenced by the
+ * pool plus an index.
+ */
+public class ObjectPool {
+
+  /**
+   * An array to hold objects, which will grow as needed.
+   */
+  private Object[] objects;
+
+  /**
+   * How many objects in the pool.
+   */
+  private int numObj;
+
+  public ObjectPool(int capacity) {
+    objects = new Object[capacity];
+    numObj = 0;
+  }
+
+  /**
+   * Returns how many objects in the pool.
+   */
+  public int size() {
+    return numObj;
+  }
+
+  /**
+   * Returns the object at position `idx` in the array.
+   */
+  public Object get(int idx) {
+    assert (idx < numObj);
+    return objects[idx];
+  }
+
+  /**
+   * Puts an object `obj` at the end of array, returns the index of it.
+   * <p/>
+   * The array will grow as needed.
+   */
+  public int put(Object obj) {
+    if (numObj >= objects.length) {
+      Object[] tmp = new Object[objects.length * 2];
+      System.arraycopy(objects, 0, tmp, 0, objects.length);
+      objects = tmp;
+    }
+    objects[numObj++] = obj;
+    return numObj - 1;
+  }
+
+  /**
+   * Replaces the object at `idx` with new one `obj`.
+   */
+  public void replace(int idx, Object obj) {
+    assert (idx < numObj);
+    objects[idx] = obj;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/ed359de5/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java
new file mode 100644
index 0000000..d512392
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util;
+
+import java.util.HashMap;
+
+/**
+ * An unique object pool stores a collection of unique objects in it.
+ */
+public class UniqueObjectPool extends ObjectPool {
+
+  /**
+   * A hash map from objects to their indexes in the array.
+   */
+  private HashMap<Object, Integer> objIndex;
+
+  public UniqueObjectPool(int capacity) {
+    super(capacity);
+    objIndex = new HashMap<Object, Integer>();
+  }
+
+  /**
+   * Put an object `obj` into the pool. If there is an existing object equals to `obj`, it will
+   * return the index of the existing one.
+   */
+  @Override
+  public int put(Object obj) {
+    if (objIndex.containsKey(obj)) {
+      return objIndex.get(obj);
+    } else {
+      int idx = super.put(obj);
+      objIndex.put(obj, idx);
+      return idx;
+    }
+  }
+
+  /**
+   * The objects can not be replaced.
+   */
+  @Override
+  public void replace(int idx, Object obj) {
+    throw new UnsupportedOperationException();
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/ed359de5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 61a29c8..57de0f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -28,7 +28,10 @@ import org.apache.spark.unsafe.types.UTF8String
 abstract class InternalRow extends Row {
 
   // This is only use for test
-  override def getString(i: Int): String = getAs[UTF8String](i).toString
+  override def getString(i: Int): String = {
+    val str = getAs[UTF8String](i)
+    if (str != null) str.toString else null
+  }
 
   // These expensive API should not be used internally.
   final override def getDecimal(i: Int): java.math.BigDecimal =

http://git-wip-us.apache.org/repos/asf/spark/blob/ed359de5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index b61d490..b11fc24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.catalyst.util.ObjectPool
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -33,6 +34,8 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
     this(schema.fields.map(_.dataType))
   }
 
+  def numFields: Int = fieldTypes.length // number of columns this converter was built for
+
   /** Re-used pointer to the unsafe row being written */
   private[this] val unsafeRow = new UnsafeRow()
 
@@ -68,8 +71,8 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
    * @param baseOffset the base offset of the destination address
    * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`.
    */
-  def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long): Int = {
-    unsafeRow.pointTo(baseObject, baseOffset, writers.length, null)
+  def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long, pool: ObjectPool): Int = {
+    unsafeRow.pointTo(baseObject, baseOffset, writers.length, pool)
 
     if (writers.length > 0) {
       // zero-out the bitset
@@ -84,16 +87,16 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
     }
 
     var fieldNumber = 0
-    var appendCursor: Int = fixedLengthSize
+    var cursor: Int = fixedLengthSize
     while (fieldNumber < writers.length) {
       if (row.isNullAt(fieldNumber)) {
         unsafeRow.setNullAt(fieldNumber)
       } else {
-        appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor)
+        cursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, cursor)
       }
       fieldNumber += 1
     }
-    appendCursor
+    cursor
   }
 
 }
@@ -108,11 +111,11 @@ private abstract class UnsafeColumnWriter {
    * @param source the row being converted
    * @param target a pointer to the converted unsafe row
    * @param column the column to write
-   * @param appendCursor the offset from the start of the unsafe row to the end of the row;
+   * @param cursor the offset from the start of the unsafe row to the end of the row;
    *                     used for calculating where variable-length data should be written
    * @return the number of variable-length bytes written
    */
-  def write(source: InternalRow, target: UnsafeRow, column: Int, appendCursor: Int): Int
+  def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int
 
   /**
    * Return the number of bytes that are needed to write this variable-length value.
@@ -134,8 +137,7 @@ private object UnsafeColumnWriter {
       case DoubleType => DoubleUnsafeColumnWriter
       case StringType => StringUnsafeColumnWriter
       case BinaryType => BinaryUnsafeColumnWriter
-      case t =>
-        throw new UnsupportedOperationException(s"Do not know how to write columns of type $t")
+      case t => ObjectUnsafeColumnWriter
     }
   }
 }
@@ -152,6 +154,7 @@ private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
 private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
 private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
 private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter
+private object ObjectUnsafeColumnWriter extends ObjectUnsafeColumnWriter // shared; holds no state
 
 private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
   // Primitives don't write to the variable-length region:
@@ -159,88 +162,56 @@ private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
 }
 
 private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
-  override def write(
-      source: InternalRow,
-      target: UnsafeRow,
-      column: Int,
-      appendCursor: Int): Int = {
+  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
     target.setNullAt(column)
     0
   }
 }
 
 private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
-  override def write(
-      source: InternalRow,
-      target: UnsafeRow,
-      column: Int,
-      appendCursor: Int): Int = {
+  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
     target.setBoolean(column, source.getBoolean(column))
     0
   }
 }
 
 private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
-  override def write(
-      source: InternalRow,
-      target: UnsafeRow,
-      column: Int,
-      appendCursor: Int): Int = {
+  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
     target.setByte(column, source.getByte(column))
     0
   }
 }
 
 private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
-  override def write(
-      source: InternalRow,
-      target: UnsafeRow,
-      column: Int,
-      appendCursor: Int): Int = {
+  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
     target.setShort(column, source.getShort(column))
     0
   }
 }
 
 private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
-  override def write(
-      source: InternalRow,
-      target: UnsafeRow,
-      column: Int,
-      appendCursor: Int): Int = {
+  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
     target.setInt(column, source.getInt(column))
     0
   }
 }
 
 private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
-  override def write(
-      source: InternalRow,
-      target: UnsafeRow,
-      column: Int,
-      appendCursor: Int): Int = {
+  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
     target.setLong(column, source.getLong(column))
     0
   }
 }
 
 private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
-  override def write(
-      source: InternalRow,
-      target: UnsafeRow,
-      column: Int,
-      appendCursor: Int): Int = {
+  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
     target.setFloat(column, source.getFloat(column))
     0
   }
 }
 
 private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
-  override def write(
-      source: InternalRow,
-      target: UnsafeRow,
-      column: Int,
-      appendCursor: Int): Int = {
+  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
     target.setDouble(column, source.getDouble(column))
     0
   }
@@ -255,12 +226,10 @@ private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter {
     ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
   }
 
-  override def write(
-      source: InternalRow,
-      target: UnsafeRow,
-      column: Int,
-      appendCursor: Int): Int = {
-    val offset = target.getBaseOffset + appendCursor
+  protected[this] def isString: Boolean
+
+  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
+    val offset = target.getBaseOffset + cursor
     val bytes = getBytes(source, column)
     val numBytes = bytes.length
     if ((numBytes & 0x07) > 0) {
@@ -274,19 +243,32 @@ private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter {
       offset,
       numBytes
     )
-    target.setLong(column, (appendCursor.toLong << 32L) | numBytes.toLong)
+    val flag = if (isString) 1L << (UnsafeRow.OFFSET_BITS * 2) else 0
+    target.setLong(column, flag | (cursor.toLong << UnsafeRow.OFFSET_BITS) | numBytes.toLong)
     ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
   }
 }
 
 private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter {
+  protected[this] def isString: Boolean = true // sets the string flag bit in the offset word
   def getBytes(source: InternalRow, column: Int): Array[Byte] = {
     source.getAs[UTF8String](column).getBytes
   }
 }
 
 private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter {
+  protected[this] def isString: Boolean = false // binary leaves the string flag bit clear
   def getBytes(source: InternalRow, column: Int): Array[Byte] = {
     source.getAs[Array[Byte]](column)
   }
 }
+
+private class ObjectUnsafeColumnWriter private() extends UnsafeColumnWriter {
+  def getSize(sourceRow: InternalRow, column: Int): Int = 0 // nothing in var-length region
+  override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
+    val value = source.get(column)
+    // Slot stores the negated pool index; NOTE(review): index 0 encodes as 0 - confirm reader.
+    target.setLong(column, -target.getPool.put(value))
+    0 // no variable-length bytes written
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/ed359de5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
index 3095ccb..6fafc2f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
@@ -23,8 +23,9 @@ import scala.util.Random
 import org.scalatest.{BeforeAndAfterEach, Matchers}
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator}
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
 import org.apache.spark.unsafe.types.UTF8String
 
 
@@ -33,10 +34,10 @@ class UnsafeFixedWidthAggregationMapSuite
   with Matchers
   with BeforeAndAfterEach {
 
-  import UnsafeFixedWidthAggregationMap._
-
   private val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
   private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
+  private def emptyProjection: Projection = // builds the initial buffer row (a single 0)
+    GenerateProjection.generate(Seq(Literal(0)), Seq(AttributeReference("price", IntegerType)()))
   private def emptyAggregationBuffer: InternalRow = InternalRow(0)
 
   private var memoryManager: TaskMemoryManager = null
@@ -52,21 +53,11 @@ class UnsafeFixedWidthAggregationMapSuite
     }
   }
 
-  test("supported schemas") {
-    assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil)))
-    assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil)))
-
-    assert(
-      !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
-    assert(
-      !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
-  }
-
   test("empty map") {
     val map = new UnsafeFixedWidthAggregationMap(
-      emptyAggregationBuffer,
-      aggBufferSchema,
-      groupKeySchema,
+      emptyProjection,
+      new UnsafeRowConverter(groupKeySchema),
+      new UnsafeRowConverter(aggBufferSchema),
       memoryManager,
       1024, // initial capacity
       false // disable perf metrics
@@ -77,9 +68,9 @@ class UnsafeFixedWidthAggregationMapSuite
 
   test("updating values for a single key") {
     val map = new UnsafeFixedWidthAggregationMap(
-      emptyAggregationBuffer,
-      aggBufferSchema,
-      groupKeySchema,
+      emptyProjection,
+      new UnsafeRowConverter(groupKeySchema),
+      new UnsafeRowConverter(aggBufferSchema),
       memoryManager,
       1024, // initial capacity
       false // disable perf metrics
@@ -103,9 +94,9 @@ class UnsafeFixedWidthAggregationMapSuite
 
   test("inserting large random keys") {
     val map = new UnsafeFixedWidthAggregationMap(
-      emptyAggregationBuffer,
-      aggBufferSchema,
-      groupKeySchema,
+      emptyProjection,
+      new UnsafeRowConverter(groupKeySchema),
+      new UnsafeRowConverter(aggBufferSchema),
       memoryManager,
       128, // initial capacity
       false // disable perf metrics
@@ -120,6 +111,36 @@ class UnsafeFixedWidthAggregationMapSuite
     }.toSet
     seenKeys.size should be (groupKeys.size)
     seenKeys should be (groupKeys)
+
+    map.free()
+  }
+
+  test("with decimal in the key and values") {
+    val groupKeySchema = StructType(StructField("price", DecimalType(10, 0)) :: Nil)
+    val aggBufferSchema = StructType(StructField("amount", DecimalType.Unlimited) :: Nil)
+    val emptyProjection = GenerateProjection.generate(Seq(Literal(Decimal(0))),
+      Seq(AttributeReference("price", DecimalType.Unlimited)()))
+    val map = new UnsafeFixedWidthAggregationMap(
+      emptyProjection,
+      new UnsafeRowConverter(groupKeySchema),
+      new UnsafeRowConverter(aggBufferSchema),
+      memoryManager,
+      1, // initial capacity
+      false // disable perf metrics
+    )
+
+    (0 until 100).foreach { i => // 100 updates spread over 10 distinct keys (i % 10)
+      val groupKey = InternalRow(Decimal(i % 10))
+      val row = map.getAggregationBuffer(groupKey)
+      row.update(0, Decimal(i))
+    }
+    val seenKeys: Set[Int] = map.iterator().asScala.map { entry =>
+      entry.key.getAs[Decimal](0).toInt
+    }.toSet
+    seenKeys.size should be (10)
+    seenKeys should be ((0 until 10).toSet) // all ten decimal keys are present
+
+    map.free()
   }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed359de5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index c0675f4..94c2f32 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -23,10 +23,11 @@ import java.util.Arrays
 import org.scalatest.Matchers
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{ObjectPool, DateTimeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.types.UTF8String
 
 class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
 
@@ -40,16 +41,21 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     row.setInt(2, 2)
 
     val sizeRequired: Int = converter.getSizeRequirement(row)
-    sizeRequired should be (8 + (3 * 8))
+    assert(sizeRequired === 8 + (3 * 8))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
-    numBytesWritten should be (sizeRequired)
+    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
     unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
-    unsafeRow.getLong(0) should be (0)
-    unsafeRow.getLong(1) should be (1)
-    unsafeRow.getInt(2) should be (2)
+    assert(unsafeRow.getLong(0) === 0)
+    assert(unsafeRow.getLong(1) === 1)
+    assert(unsafeRow.getInt(2) === 2)
+
+    unsafeRow.setLong(1, 3)
+    assert(unsafeRow.getLong(1) === 3)
+    unsafeRow.setInt(2, 4)
+    assert(unsafeRow.getInt(2) === 4)
   }
 
   test("basic conversion with primitive, string and binary types") {
@@ -58,22 +64,67 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
 
     val row = new SpecificMutableRow(fieldTypes)
     row.setLong(0, 0)
-    row.setString(1, "Hello")
+    row.update(1, UTF8String.fromString("Hello"))
     row.update(2, "World".getBytes)
 
     val sizeRequired: Int = converter.getSizeRequirement(row)
-    sizeRequired should be (8 + (8 * 3) +
+    assert(sizeRequired === 8 + (8 * 3) +
       ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
       ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
-    numBytesWritten should be (sizeRequired)
+    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
-    unsafeRow.getLong(0) should be (0)
-    unsafeRow.getString(1) should be ("Hello")
-    unsafeRow.getBinary(2) should be ("World".getBytes)
+    val pool = new ObjectPool(10)
+    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+    assert(unsafeRow.getLong(0) === 0)
+    assert(unsafeRow.getString(1) === "Hello")
+    assert(unsafeRow.get(2) === "World".getBytes)
+
+    unsafeRow.update(1, UTF8String.fromString("World"))
+    assert(unsafeRow.getString(1) === "World")
+    assert(pool.size === 0)
+    unsafeRow.update(1, UTF8String.fromString("Hello World"))
+    assert(unsafeRow.getString(1) === "Hello World")
+    assert(pool.size === 1)
+
+    unsafeRow.update(2, "World".getBytes)
+    assert(unsafeRow.get(2) === "World".getBytes)
+    assert(pool.size === 1)
+    unsafeRow.update(2, "Hello World".getBytes)
+    assert(unsafeRow.get(2) === "Hello World".getBytes)
+    assert(pool.size === 2)
+  }
+
+  test("basic conversion with primitive, decimal and array") {
+    val fieldTypes: Array[DataType] = Array(LongType, DecimalType(10, 0), ArrayType(IntegerType))
+    val converter = new UnsafeRowConverter(fieldTypes)
+
+    val row = new SpecificMutableRow(fieldTypes)
+    row.setLong(0, 0)
+    row.update(1, Decimal(1))
+    row.update(2, Array(2)) // element type matches ArrayType(IntegerType)
+
+    val pool = new ObjectPool(10)
+    val sizeRequired: Int = converter.getSizeRequirement(row)
+    assert(sizeRequired === 8 + (8 * 3))
+    val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
+    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
+    assert(numBytesWritten === sizeRequired)
+    assert(pool.size === 2)
+
+    val unsafeRow = new UnsafeRow()
+    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+    assert(unsafeRow.getLong(0) === 0)
+    assert(unsafeRow.get(1) === Decimal(1))
+    assert(unsafeRow.get(2) === Array(2))
+
+    unsafeRow.update(1, Decimal(2))
+    assert(unsafeRow.get(1) === Decimal(2))
+    unsafeRow.update(2, Array(3, 4))
+    assert(unsafeRow.get(2) === Array(3, 4))
+    assert(pool.size === 2) // updates reuse the existing pool slots
   }
 
   test("basic conversion with primitive, string, date and timestamp types") {
@@ -87,21 +138,27 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25")))
 
     val sizeRequired: Int = converter.getSizeRequirement(row)
-    sizeRequired should be (8 + (8 * 4) +
+    assert(sizeRequired === 8 + (8 * 4) +
       ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
-    numBytesWritten should be (sizeRequired)
+    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
     unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
-    unsafeRow.getLong(0) should be (0)
-    unsafeRow.getString(1) should be ("Hello")
+    assert(unsafeRow.getLong(0) === 0)
+    assert(unsafeRow.getString(1) === "Hello")
     // Date is represented as Int in unsafeRow
-    DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) should be (Date.valueOf("1970-01-01"))
+    assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("1970-01-01"))
     // Timestamp is represented as Long in unsafeRow
-    DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be
-      (Timestamp.valueOf("2015-05-08 08:10:25"))
+    assert(DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) ===
+      Timestamp.valueOf("2015-05-08 08:10:25"))
+
+    unsafeRow.setInt(2, DateTimeUtils.fromJavaDate(Date.valueOf("2015-06-22")))
+    assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("2015-06-22"))
+    unsafeRow.setLong(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-06-22 08:10:25")))
+    assert(DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) ===
+      Timestamp.valueOf("2015-06-22 08:10:25"))
   }
 
   test("null handling") {
@@ -113,7 +170,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       IntegerType,
       LongType,
       FloatType,
-      DoubleType)
+      DoubleType,
+      StringType,
+      BinaryType,
+      DecimalType.Unlimited,
+      ArrayType(IntegerType)
+    )
     val converter = new UnsafeRowConverter(fieldTypes)
 
     val rowWithAllNullColumns: InternalRow = {
@@ -127,8 +189,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns)
     val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
     val numBytesWritten = converter.writeRow(
-      rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
-    numBytesWritten should be (sizeRequired)
+      rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    assert(numBytesWritten === sizeRequired)
 
     val createdFromNull = new UnsafeRow()
     createdFromNull.pointTo(
@@ -136,13 +198,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     for (i <- 0 to fieldTypes.length - 1) {
       assert(createdFromNull.isNullAt(i))
     }
-    createdFromNull.getBoolean(1) should be (false)
-    createdFromNull.getByte(2) should be (0)
-    createdFromNull.getShort(3) should be (0)
-    createdFromNull.getInt(4) should be (0)
-    createdFromNull.getLong(5) should be (0)
+    assert(createdFromNull.getBoolean(1) === false)
+    assert(createdFromNull.getByte(2) === 0)
+    assert(createdFromNull.getShort(3) === 0)
+    assert(createdFromNull.getInt(4) === 0)
+    assert(createdFromNull.getLong(5) === 0)
     assert(java.lang.Float.isNaN(createdFromNull.getFloat(6)))
-    assert(java.lang.Double.isNaN(createdFromNull.getFloat(7)))
+    assert(java.lang.Double.isNaN(createdFromNull.getDouble(7)))
+    assert(createdFromNull.getString(8) === null)
+    assert(createdFromNull.get(9) === null)
+    assert(createdFromNull.get(10) === null)
+    assert(createdFromNull.get(11) === null)
 
     // If we have an UnsafeRow with columns that are initially non-null and we null out those
     // columns, then the serialized row representation should be identical to what we would get by
@@ -157,28 +223,68 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       r.setLong(5, 500)
       r.setFloat(6, 600)
       r.setDouble(7, 700)
+      r.update(8, UTF8String.fromString("hello"))
+      r.update(9, "world".getBytes)
+      r.update(10, Decimal(10))
+      r.update(11, Array(11))
       r
     }
-    val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
+    val pool = new ObjectPool(1)
+    val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2)
     converter.writeRow(
-      rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
+      rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
     val setToNullAfterCreation = new UnsafeRow()
     setToNullAfterCreation.pointTo(
-      setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+      setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
 
-    setToNullAfterCreation.isNullAt(0) should be (rowWithNoNullColumns.isNullAt(0))
-    setToNullAfterCreation.getBoolean(1) should be (rowWithNoNullColumns.getBoolean(1))
-    setToNullAfterCreation.getByte(2) should be (rowWithNoNullColumns.getByte(2))
-    setToNullAfterCreation.getShort(3) should be (rowWithNoNullColumns.getShort(3))
-    setToNullAfterCreation.getInt(4) should be (rowWithNoNullColumns.getInt(4))
-    setToNullAfterCreation.getLong(5) should be (rowWithNoNullColumns.getLong(5))
-    setToNullAfterCreation.getFloat(6) should be (rowWithNoNullColumns.getFloat(6))
-    setToNullAfterCreation.getDouble(7) should be (rowWithNoNullColumns.getDouble(7))
+    assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
+    assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))
+    assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2))
+    assert(setToNullAfterCreation.getShort(3) === rowWithNoNullColumns.getShort(3))
+    assert(setToNullAfterCreation.getInt(4) === rowWithNoNullColumns.getInt(4))
+    assert(setToNullAfterCreation.getLong(5) === rowWithNoNullColumns.getLong(5))
+    assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6))
+    assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7))
+    assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8))
+    assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9))
+    assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10))
+    assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
 
     for (i <- 0 to fieldTypes.length - 1) {
+      if (i >= 8) {
+        setToNullAfterCreation.update(i, null)
+      }
       setToNullAfterCreation.setNullAt(i)
     }
-    assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer))
+    // There are some garbage left in the var-length area
+    assert(Arrays.equals(createdFromNullBuffer,
+      java.util.Arrays.copyOf(setToNullAfterCreationBuffer, sizeRequired / 8)))
+
+    setToNullAfterCreation.setNullAt(0)
+    setToNullAfterCreation.setBoolean(1, false)
+    setToNullAfterCreation.setByte(2, 20)
+    setToNullAfterCreation.setShort(3, 30)
+    setToNullAfterCreation.setInt(4, 400)
+    setToNullAfterCreation.setLong(5, 500)
+    setToNullAfterCreation.setFloat(6, 600)
+    setToNullAfterCreation.setDouble(7, 700)
+    setToNullAfterCreation.update(8, UTF8String.fromString("hello"))
+    setToNullAfterCreation.update(9, "world".getBytes)
+    setToNullAfterCreation.update(10, Decimal(10))
+    setToNullAfterCreation.update(11, Array(11))
+
+    assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
+    assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))
+    assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2))
+    assert(setToNullAfterCreation.getShort(3) === rowWithNoNullColumns.getShort(3))
+    assert(setToNullAfterCreation.getInt(4) === rowWithNoNullColumns.getInt(4))
+    assert(setToNullAfterCreation.getLong(5) === rowWithNoNullColumns.getLong(5))
+    assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6))
+    assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7))
+    assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8))
+    assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9))
+    assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10))
+    assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
   }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed359de5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala
new file mode 100644
index 0000000..94764df
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.scalatest.Matchers
+
+import org.apache.spark.SparkFunSuite
+
+class ObjectPoolSuite extends SparkFunSuite with Matchers {
+
+  test("pool") {
+    // A plain pool hands out indices in insertion order.
+    val pool = new ObjectPool(1)
+    val values = Seq[AnyRef](1, "hello", false)
+    values.zipWithIndex.foreach { case (value, expectedIdx) =>
+      assert(pool.put(value) === expectedIdx)
+    }
+    values.zipWithIndex.foreach { case (value, idx) =>
+      assert(pool.get(idx) === value)
+    }
+    assert(pool.size() === 3)
+
+    // replace() swaps the object at an index without growing the pool.
+    pool.replace(1, "world")
+    assert(pool.get(1) === "world")
+    assert(pool.size() === 3)
+  }
+
+  test("unique pool") {
+    // Putting an equal object again returns its existing index.
+    val pool = new UniqueObjectPool(1)
+    for (_ <- 1 to 2) {
+      assert(pool.put(1) === 0)
+      assert(pool.put("hello") === 1)
+    }
+    assert(pool.get(0) === 1)
+    assert(pool.get(1) === "hello")
+    assert(pool.size() === 2)
+    // A unique pool does not allow replacing entries.
+    intercept[UnsupportedOperationException] { pool.replace(1, "world") }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/ed359de5/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index ba2c8f5..44930f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -238,11 +238,6 @@ case class GeneratedAggregate(
       StructType(fields)
     }
 
-    val schemaSupportsUnsafe: Boolean = {
-      UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
-        UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema)
-    }
-
     child.execute().mapPartitions { iter =>
       // Builds a new custom class for holding the results of aggregation for a group.
       val initialValues = computeFunctions.flatMap(_.initialValues)
@@ -283,12 +278,12 @@ case class GeneratedAggregate(
 
         val resultProjection = resultProjectionBuilder()
         Iterator(resultProjection(buffer))
-      } else if (unsafeEnabled && schemaSupportsUnsafe) {
+      } else if (unsafeEnabled) {
         log.info("Using Unsafe-based aggregator")
         val aggregationMap = new UnsafeFixedWidthAggregationMap(
-          newAggregationBuffer(EmptyRow),
-          aggregationBufferSchema,
-          groupKeySchema,
+          newAggregationBuffer,
+          new UnsafeRowConverter(groupKeySchema),
+          new UnsafeRowConverter(aggregationBufferSchema),
           TaskContext.get.taskMemoryManager(),
           1024 * 16, // initial capacity
           false // disable tracking of performance metrics
@@ -323,9 +318,6 @@ case class GeneratedAggregate(
           }
         }
       } else {
-        if (unsafeEnabled) {
-          log.info("Not using Unsafe-based aggregator because it is not supported for this schema")
-        }
         val buffers = new java.util.HashMap[InternalRow, MutableRow]()
 
         var currentRow: InternalRow = null


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org