You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2016/12/29 02:59:46 UTC

spark git commit: [SPARK-16213][SQL] Reduce runtime overhead of a program that creates a primitive array in DataFrame

Repository: spark
Updated Branches:
  refs/heads/master 092c6725b -> 93f35569f


[SPARK-16213][SQL] Reduce runtime overhead of a program that creates a primitive array in DataFrame

## What changes were proposed in this pull request?

This PR reduces the runtime overhead of a program that creates a primitive array in a DataFrame, using an approach similar to #15044. The generated code performs a boxing operation in each assignment from InternalRow to an `Object[]` temporary array (at lines 051 and 061 of the generated code without this PR). If we know that the array elements are of a primitive type, we apply the following optimizations:
1. Eliminate a pair of `isNullAt()` and a null assignment
2. Allocate a primitive array instead of `Object[]` (eliminating boxing operations)
3. Create `UnsafeArrayData` by using `UnsafeArrayWriter` to keep a primitive array in row format, instead of performing heavyweight operations in the constructor of `GenericArrayData`
The PR also applies the same optimizations to `CreateMap`.

Here are performance results of [DataFrame programs](https://github.com/kiszk/spark/blob/6bf54ec5e227689d69f6db991e9ecbc54e153d0a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala#L83-L112); with this PR they run up to 17.9x faster than without it.

```
Without SPARK-16043
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)
Read a primitive array in DataFrame:     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                           3805 / 4150          0.0      507308.9       1.0X
Double                                        3593 / 3852          0.0      479056.9       1.1X

With SPARK-16043
Read a primitive array in DataFrame:     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                            213 /  271          0.0       28387.5       1.0X
Double                                         204 /  223          0.0       27250.9       1.0X
```
Note: #15780 is enabled for these measurements.

A motivating example

``` scala
val df = sparkContext.parallelize(Seq(0.0d, 1.0d), 1).toDF
df.selectExpr("Array(value + 1.1d, value + 2.2d)").show
```

Generated code without this PR

``` java
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private UnsafeRow serializefromobject_result;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 011 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 012 */   private Object[] project_values;
/* 013 */   private UnsafeRow project_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder project_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter project_rowWriter;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter project_arrayWriter;
/* 017 */
/* 018 */   public GeneratedIterator(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */     inputadapter_input = inputs[0];
/* 026 */     serializefromobject_result = new UnsafeRow(1);
/* 027 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0);
/* 028 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 029 */     this.project_values = null;
/* 030 */     project_result = new UnsafeRow(1);
/* 031 */     this.project_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(project_result, 32);
/* 032 */     this.project_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(project_holder, 1);
/* 033 */     this.project_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 034 */
/* 035 */   }
/* 036 */
/* 037 */   protected void processNext() throws java.io.IOException {
/* 038 */     while (inputadapter_input.hasNext()) {
/* 039 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 040 */       double inputadapter_value = inputadapter_row.getDouble(0);
/* 041 */
/* 042 */       final boolean project_isNull = false;
/* 043 */       this.project_values = new Object[2];
/* 044 */       boolean project_isNull1 = false;
/* 045 */
/* 046 */       double project_value1 = -1.0;
/* 047 */       project_value1 = inputadapter_value + 1.1D;
/* 048 */       if (false) {
/* 049 */         project_values[0] = null;
/* 050 */       } else {
/* 051 */         project_values[0] = project_value1;
/* 052 */       }
/* 053 */
/* 054 */       boolean project_isNull4 = false;
/* 055 */
/* 056 */       double project_value4 = -1.0;
/* 057 */       project_value4 = inputadapter_value + 2.2D;
/* 058 */       if (false) {
/* 059 */         project_values[1] = null;
/* 060 */       } else {
/* 061 */         project_values[1] = project_value4;
/* 062 */       }
/* 063 */
/* 064 */       final ArrayData project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_values);
/* 065 */       this.project_values = null;
/* 066 */       project_holder.reset();
/* 067 */
/* 068 */       project_rowWriter.zeroOutNullBytes();
/* 069 */
/* 070 */       if (project_isNull) {
/* 071 */         project_rowWriter.setNullAt(0);
/* 072 */       } else {
/* 073 */         // Remember the current cursor so that we can calculate how many bytes are
/* 074 */         // written later.
/* 075 */         final int project_tmpCursor = project_holder.cursor;
/* 076 */
/* 077 */         if (project_value instanceof UnsafeArrayData) {
/* 078 */           final int project_sizeInBytes = ((UnsafeArrayData) project_value).getSizeInBytes();
/* 079 */           // grow the global buffer before writing data.
/* 080 */           project_holder.grow(project_sizeInBytes);
/* 081 */           ((UnsafeArrayData) project_value).writeToMemory(project_holder.buffer, project_holder.cursor);
/* 082 */           project_holder.cursor += project_sizeInBytes;
/* 083 */
/* 084 */         } else {
/* 085 */           final int project_numElements = project_value.numElements();
/* 086 */           project_arrayWriter.initialize(project_holder, project_numElements, 8);
/* 087 */
/* 088 */           for (int project_index = 0; project_index < project_numElements; project_index++) {
/* 089 */             if (project_value.isNullAt(project_index)) {
/* 090 */               project_arrayWriter.setNullDouble(project_index);
/* 091 */             } else {
/* 092 */               final double project_element = project_value.getDouble(project_index);
/* 093 */               project_arrayWriter.write(project_index, project_element);
/* 094 */             }
/* 095 */           }
/* 096 */         }
/* 097 */
/* 098 */         project_rowWriter.setOffsetAndSize(0, project_tmpCursor, project_holder.cursor - project_tmpCursor);
/* 099 */       }
/* 100 */       project_result.setTotalSize(project_holder.totalSize());
/* 101 */       append(project_result);
/* 102 */       if (shouldStop()) return;
/* 103 */     }
/* 104 */   }
/* 105 */ }
```

Generated code with this PR

``` java
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private UnsafeRow serializefromobject_result;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 011 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 012 */   private UnsafeArrayData project_arrayData;
/* 013 */   private UnsafeRow project_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder project_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter project_rowWriter;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter project_arrayWriter;
/* 017 */
/* 018 */   public GeneratedIterator(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */     inputadapter_input = inputs[0];
/* 026 */     serializefromobject_result = new UnsafeRow(1);
/* 027 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0);
/* 028 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 029 */
/* 030 */     project_result = new UnsafeRow(1);
/* 031 */     this.project_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(project_result, 32);
/* 032 */     this.project_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(project_holder, 1);
/* 033 */     this.project_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 034 */
/* 035 */   }
/* 036 */
/* 037 */   protected void processNext() throws java.io.IOException {
/* 038 */     while (inputadapter_input.hasNext()) {
/* 039 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 040 */       double inputadapter_value = inputadapter_row.getDouble(0);
/* 041 */
/* 042 */       byte[] project_array = new byte[32];
/* 043 */       project_arrayData = new UnsafeArrayData();
/* 044 */       Platform.putLong(project_array, 16, 2);
/* 045 */       project_arrayData.pointTo(project_array, 16, 32);
/* 046 */
/* 047 */       boolean project_isNull1 = false;
/* 048 */
/* 049 */       double project_value1 = -1.0;
/* 050 */       project_value1 = inputadapter_value + 1.1D;
/* 051 */       if (false) {
/* 052 */         project_arrayData.setNullAt(0);
/* 053 */       } else {
/* 054 */         project_arrayData.setDouble(0, project_value1);
/* 055 */       }
/* 056 */
/* 057 */       boolean project_isNull4 = false;
/* 058 */
/* 059 */       double project_value4 = -1.0;
/* 060 */       project_value4 = inputadapter_value + 2.2D;
/* 061 */       if (false) {
/* 062 */         project_arrayData.setNullAt(1);
/* 063 */       } else {
/* 064 */         project_arrayData.setDouble(1, project_value4);
/* 065 */       }
/* 066 */       project_holder.reset();
/* 067 */
/* 068 */       // Remember the current cursor so that we can calculate how many bytes are
/* 069 */       // written later.
/* 070 */       final int project_tmpCursor = project_holder.cursor;
/* 071 */
/* 072 */       if (project_arrayData instanceof UnsafeArrayData) {
/* 073 */         final int project_sizeInBytes = ((UnsafeArrayData) project_arrayData).getSizeInBytes();
/* 074 */         // grow the global buffer before writing data.
/* 075 */         project_holder.grow(project_sizeInBytes);
/* 076 */         ((UnsafeArrayData) project_arrayData).writeToMemory(project_holder.buffer, project_holder.cursor);
/* 077 */         project_holder.cursor += project_sizeInBytes;
/* 078 */
/* 079 */       } else {
/* 080 */         final int project_numElements = project_arrayData.numElements();
/* 081 */         project_arrayWriter.initialize(project_holder, project_numElements, 8);
/* 082 */
/* 083 */         for (int project_index = 0; project_index < project_numElements; project_index++) {
/* 084 */           if (project_arrayData.isNullAt(project_index)) {
/* 085 */             project_arrayWriter.setNullDouble(project_index);
/* 086 */           } else {
/* 087 */             final double project_element = project_arrayData.getDouble(project_index);
/* 088 */             project_arrayWriter.write(project_index, project_element);
/* 089 */           }
/* 090 */         }
/* 091 */       }
/* 092 */
/* 093 */       project_rowWriter.setOffsetAndSize(0, project_tmpCursor, project_holder.cursor - project_tmpCursor);
/* 094 */       project_result.setTotalSize(project_holder.totalSize());
/* 095 */       append(project_result);
/* 096 */       if (shouldStop()) return;
/* 097 */     }
/* 098 */   }
/* 099 */ }
```
## How was this patch tested?

Added unit tests into `DataFrameComplexTypeSuite`

Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #13909 from kiszk/SPARK-16213.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/93f35569
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/93f35569
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/93f35569

Branch: refs/heads/master
Commit: 93f35569fd4e7dc1e4037d3df538a21c526f9c5d
Parents: 092c672
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Authored: Thu Dec 29 10:59:37 2016 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Dec 29 10:59:37 2016 +0800

----------------------------------------------------------------------
 .../catalyst/expressions/UnsafeArrayData.java   |  52 ++++++
 .../expressions/complexTypeCreator.scala        | 174 ++++++++++++-------
 .../spark/sql/catalyst/util/ArrayData.scala     |  13 ++
 .../sql/catalyst/util/GenericArrayData.scala    |   4 +
 .../expressions/CodeGenerationSuite.scala       |  34 ++--
 .../catalyst/expressions/ComplexTypeSuite.scala |   4 +
 .../expressions/ExpressionEvalHelper.scala      |  30 +++-
 .../sql/execution/vectorized/ColumnVector.java  |   6 +
 .../execution/ObjectHashAggregateSuite.scala    |   4 +-
 9 files changed, 230 insertions(+), 91 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/93f35569/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index e8c3387..64ab01c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -287,6 +287,58 @@ public final class UnsafeArrayData extends ArrayData {
     return map;
   }
 
+  @Override
+  public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); }
+
+  public void setNullAt(int ordinal) {
+    assertIndexIsValid(ordinal);
+    BitSetMethods.set(baseObject, baseOffset + 8, ordinal);
+
+    /* we assume the corrresponding column was already 0 or
+       will be set to 0 later by the caller side */
+  }
+
+  public void setBoolean(int ordinal, boolean value) {
+    assertIndexIsValid(ordinal);
+    Platform.putBoolean(baseObject, getElementOffset(ordinal, 1), value);
+  }
+
+  public void setByte(int ordinal, byte value) {
+    assertIndexIsValid(ordinal);
+    Platform.putByte(baseObject, getElementOffset(ordinal, 1), value);
+  }
+
+  public void setShort(int ordinal, short value) {
+    assertIndexIsValid(ordinal);
+    Platform.putShort(baseObject, getElementOffset(ordinal, 2), value);
+  }
+
+  public void setInt(int ordinal, int value) {
+    assertIndexIsValid(ordinal);
+    Platform.putInt(baseObject, getElementOffset(ordinal, 4), value);
+  }
+
+  public void setLong(int ordinal, long value) {
+    assertIndexIsValid(ordinal);
+    Platform.putLong(baseObject, getElementOffset(ordinal, 8), value);
+  }
+
+  public void setFloat(int ordinal, float value) {
+    if (Float.isNaN(value)) {
+      value = Float.NaN;
+    }
+    assertIndexIsValid(ordinal);
+    Platform.putFloat(baseObject, getElementOffset(ordinal, 4), value);
+  }
+
+  public void setDouble(int ordinal, double value) {
+    if (Double.isNaN(value)) {
+      value = Double.NaN;
+    }
+    assertIndexIsValid(ordinal);
+    Platform.putDouble(baseObject, getElementOffset(ordinal, 8), value);
+  }
+
   // This `hashCode` computation could consume much processor time for large data.
   // If the computation becomes a bottleneck, we can use a light-weight logic; the first fixed bytes
   // are used to compute `hashCode` (See `Vector.hashCode`).

http://git-wip-us.apache.org/repos/asf/spark/blob/93f35569/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 599fb63..22277ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.analysis.Star
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -43,7 +44,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array")
 
-  override def dataType: DataType = {
+  override def dataType: ArrayType = {
     ArrayType(
       children.headOption.map(_.dataType).getOrElse(NullType),
       containsNull = children.exists(_.nullable))
@@ -56,33 +57,99 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val arrayClass = classOf[GenericArrayData].getName
-    val values = ctx.freshName("values")
-    ctx.addMutableState("Object[]", values, s"this.$values = null;")
-
-    ev.copy(code = s"""
-      this.$values = new Object[${children.size}];""" +
-      ctx.splitExpressions(
-        ctx.INPUT_ROW,
-        children.zipWithIndex.map { case (e, i) =>
-          val eval = e.genCode(ctx)
-          eval.code + s"""
-            if (${eval.isNull}) {
-              $values[$i] = null;
-            } else {
-              $values[$i] = ${eval.value};
-            }
-           """
-        }) +
-      s"""
-        final ArrayData ${ev.value} = new $arrayClass($values);
-        this.$values = null;
-      """, isNull = "false")
+    val et = dataType.elementType
+    val evals = children.map(e => e.genCode(ctx))
+    val (preprocess, assigns, postprocess, arrayData) =
+      GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
+    ev.copy(
+      code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess,
+      value = arrayData,
+      isNull = "false")
   }
 
   override def prettyName: String = "array"
 }
 
+private [sql] object GenArrayData {
+  /**
+   * Return Java code pieces based on DataType and isPrimitive to allocate ArrayData class
+   *
+   * @param ctx a [[CodegenContext]]
+   * @param elementType data type of underlying array elements
+   * @param elementsCode a set of [[ExprCode]] for each element of an underlying array
+   * @param isMapKey if true, throw an exception when the element is null
+   * @return (code pre-assignments, assignments to each array elements, code post-assignments,
+   *           arrayData name)
+   */
+  def genCodeToCreateArrayData(
+      ctx: CodegenContext,
+      elementType: DataType,
+      elementsCode: Seq[ExprCode],
+      isMapKey: Boolean): (String, Seq[String], String, String) = {
+    val arrayName = ctx.freshName("array")
+    val arrayDataName = ctx.freshName("arrayData")
+    val numElements = elementsCode.length
+
+    if (!ctx.isPrimitiveType(elementType)) {
+      val genericArrayClass = classOf[GenericArrayData].getName
+      ctx.addMutableState("Object[]", arrayName,
+        s"this.$arrayName = new Object[${numElements}];")
+
+      val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
+        val isNullAssignment = if (!isMapKey) {
+          s"$arrayName[$i] = null;"
+        } else {
+          "throw new RuntimeException(\"Cannot use null as map key!\");"
+        }
+        eval.code + s"""
+         if (${eval.isNull}) {
+           $isNullAssignment
+         } else {
+           $arrayName[$i] = ${eval.value};
+         }
+       """
+      }
+
+      ("",
+       assignments,
+       s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);",
+       arrayDataName)
+    } else {
+      val unsafeArraySizeInBytes =
+        UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
+        ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
+      val baseOffset = Platform.BYTE_ARRAY_OFFSET
+      ctx.addMutableState("UnsafeArrayData", arrayDataName, "");
+
+      val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
+      val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
+        val isNullAssignment = if (!isMapKey) {
+          s"$arrayDataName.setNullAt($i);"
+        } else {
+          "throw new RuntimeException(\"Cannot use null as map key!\");"
+        }
+        eval.code + s"""
+         if (${eval.isNull}) {
+           $isNullAssignment
+         } else {
+           $arrayDataName.set$primitiveValueTypeName($i, ${eval.value});
+         }
+       """
+      }
+
+      (s"""
+        byte[] $arrayName = new byte[$unsafeArraySizeInBytes];
+        $arrayDataName = new UnsafeArrayData();
+        Platform.putLong($arrayName, $baseOffset, $numElements);
+        $arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes);
+      """,
+       assignments,
+       "",
+       arrayDataName)
+    }
+  }
+}
+
 /**
  * Returns a catalyst Map containing the evaluation of all children expressions as keys and values.
  * The children are a flatted sequence of kv pairs, e.g. (key1, value1, key2, value2, ...)
@@ -133,49 +200,26 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val arrayClass = classOf[GenericArrayData].getName
     val mapClass = classOf[ArrayBasedMapData].getName
-    val keyArray = ctx.freshName("keyArray")
-    val valueArray = ctx.freshName("valueArray")
-    ctx.addMutableState("Object[]", keyArray, s"this.$keyArray = null;")
-    ctx.addMutableState("Object[]", valueArray, s"this.$valueArray = null;")
-
-    val keyData = s"new $arrayClass($keyArray)"
-    val valueData = s"new $arrayClass($valueArray)"
-    ev.copy(code = s"""
-      $keyArray = new Object[${keys.size}];
-      $valueArray = new Object[${values.size}];""" +
-      ctx.splitExpressions(
-        ctx.INPUT_ROW,
-        keys.zipWithIndex.map { case (key, i) =>
-          val eval = key.genCode(ctx)
-          s"""
-            ${eval.code}
-            if (${eval.isNull}) {
-              throw new RuntimeException("Cannot use null as map key!");
-            } else {
-              $keyArray[$i] = ${eval.value};
-            }
-          """
-        }) +
-      ctx.splitExpressions(
-        ctx.INPUT_ROW,
-        values.zipWithIndex.map { case (value, i) =>
-          val eval = value.genCode(ctx)
-          s"""
-            ${eval.code}
-            if (${eval.isNull}) {
-              $valueArray[$i] = null;
-            } else {
-              $valueArray[$i] = ${eval.value};
-            }
-          """
-        }) +
+    val MapType(keyDt, valueDt, _) = dataType
+    val evalKeys = keys.map(e => e.genCode(ctx))
+    val evalValues = values.map(e => e.genCode(ctx))
+    val (preprocessKeyData, assignKeys, postprocessKeyData, keyArrayData) =
+      GenArrayData.genCodeToCreateArrayData(ctx, keyDt, evalKeys, true)
+    val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) =
+      GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false)
+    val code =
       s"""
-        final MapData ${ev.value} = new $mapClass($keyData, $valueData);
-        this.$keyArray = null;
-        this.$valueArray = null;
-      """, isNull = "false")
+       final boolean ${ev.isNull} = false;
+       $preprocessKeyData
+       ${ctx.splitExpressions(ctx.INPUT_ROW, assignKeys)}
+       $postprocessKeyData
+       $preprocessValueData
+       ${ctx.splitExpressions(ctx.INPUT_ROW, assignValues)}
+       $postprocessValueData
+       final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData);
+      """
+    ev.copy(code = code)
   }
 
   override def prettyName: String = "map"

http://git-wip-us.apache.org/repos/asf/spark/blob/93f35569/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
index 140e86d..9beef41 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
@@ -42,6 +42,19 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
 
   def array: Array[Any]
 
+  def setNullAt(i: Int): Unit
+
+  def update(i: Int, value: Any): Unit
+
+  // default implementation (slow)
+  def setBoolean(i: Int, value: Boolean): Unit = update(i, value)
+  def setByte(i: Int, value: Byte): Unit = update(i, value)
+  def setShort(i: Int, value: Short): Unit = update(i, value)
+  def setInt(i: Int, value: Int): Unit = update(i, value)
+  def setLong(i: Int, value: Long): Unit = update(i, value)
+  def setFloat(i: Int, value: Float): Unit = update(i, value)
+  def setDouble(i: Int, value: Double): Unit = update(i, value)
+
   def toBooleanArray(): Array[Boolean] = {
     val size = numElements()
     val values = new Array[Boolean](size)

http://git-wip-us.apache.org/repos/asf/spark/blob/93f35569/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
index 7ee9581..dd660c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
@@ -71,6 +71,10 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
   override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
   override def getMap(ordinal: Int): MapData = getAs(ordinal)
 
+  override def setNullAt(ordinal: Int): Unit = array(ordinal) = null
+
+  override def update(ordinal: Int, value: Any): Unit = array(ordinal) = value
+
   override def toString(): String = array.mkString("[", ",", "]")
 
   override def equals(o: Any): Boolean = {

http://git-wip-us.apache.org/repos/asf/spark/blob/93f35569/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index ee5d1f6..587022f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ThreadUtils
@@ -71,7 +71,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
     val expected = Seq.fill(length)(true)
 
-    if (!checkResult(actual, expected)) {
+    if (actual != expected) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -106,9 +106,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr))
     val plan = GenerateMutableProjection.generate(expressions)
     val actual = plan(null).toSeq(expressions.map(_.dataType))
-    val expected = Seq(UTF8String.fromString("abc"))
+    assert(actual.length == 1)
+    val expected = UTF8String.fromString("abc")
 
-    if (!checkResult(actual, expected)) {
+    if (!checkResult(actual.head, expected, expressions.head.dataType)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -118,9 +119,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     val expressions = Seq(CreateArray(List.fill(length)(EqualTo(Literal(1), Literal(1)))))
     val plan = GenerateMutableProjection.generate(expressions)
     val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
-    val expected = Seq(new GenericArrayData(Seq.fill(length)(true)))
+    assert(actual.length == 1)
+    val expected = UnsafeArrayData.fromPrimitiveArray(Array.fill(length)(true))
 
-    if (!checkResult(actual, expected)) {
+    if (!checkResult(actual.head, expected, expressions.head.dataType)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -132,12 +134,11 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
         case (expr, i) => Seq(Literal(i), expr)
       }))
     val plan = GenerateMutableProjection.generate(expressions)
-    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)).map {
-      case m: ArrayBasedMapData => ArrayBasedMapData.toScalaMap(m)
-    }
-    val expected = (0 until length).map((_, true)).toMap :: Nil
+    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
+    assert(actual.length == 1)
+    val expected = ArrayBasedMapData((0 until length).toArray, Array.fill(length)(true))
 
-    if (!checkResult(actual, expected)) {
+    if (!checkResult(actual.head, expected, expressions.head.dataType)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -149,7 +150,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
     val expected = Seq(InternalRow(Seq.fill(length)(true): _*))
 
-    if (!checkResult(actual, expected)) {
+    if (!checkResult(actual, expected, expressions.head.dataType)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -162,9 +163,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
       }))
     val plan = GenerateMutableProjection.generate(expressions)
     val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
-    val expected = Seq(InternalRow(Seq.fill(length)(true): _*))
+    assert(actual.length == 1)
+    val expected = InternalRow(Seq.fill(length)(true): _*)
 
-    if (!checkResult(actual, expected)) {
+    if (!checkResult(actual.head, expected, expressions.head.dataType)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -177,7 +179,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
     val expected = Seq(Row.fromSeq(Seq.fill(length)(1)))
 
-    if (!checkResult(actual, expected)) {
+    if (actual != expected) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -194,7 +196,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     val expected = Seq.fill(length)(
       DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-07-24 07:00:00")))
 
-    if (!checkResult(actual, expected)) {
+    if (actual != expected) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/93f35569/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index c21c6de..abe1d2b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -120,16 +120,20 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("CreateArray") {
     val intSeq = Seq(5, 10, 15, 20, 25)
     val longSeq = intSeq.map(_.toLong)
+    val byteSeq = intSeq.map(_.toByte)  // Byte: a third primitive element type besides Int/Long
     val strSeq = intSeq.map(_.toString)
     checkEvaluation(CreateArray(intSeq.map(Literal(_))), intSeq, EmptyRow)
     checkEvaluation(CreateArray(longSeq.map(Literal(_))), longSeq, EmptyRow)
+    checkEvaluation(CreateArray(byteSeq.map(Literal(_))), byteSeq, EmptyRow)
     checkEvaluation(CreateArray(strSeq.map(Literal(_))), strSeq, EmptyRow)
 
     val intWithNull = intSeq.map(Literal(_)) :+ Literal.create(null, IntegerType)
     val longWithNull = longSeq.map(Literal(_)) :+ Literal.create(null, LongType)
+    val byteWithNull = byteSeq.map(Literal(_)) :+ Literal.create(null, ByteType)
     val strWithNull = strSeq.map(Literal(_)) :+ Literal.create(null, StringType)
     checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow)
     checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow)
+    checkEvaluation(CreateArray(byteWithNull), byteSeq :+ null, EmptyRow)  // trailing null
     checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow)
     checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/93f35569/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index f836504..1ba6dd1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -28,8 +28,8 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
 import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
-import org.apache.spark.sql.catalyst.util.MapData
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
 /**
@@ -59,14 +59,28 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
    * Check the equality between result of expression and expected value, it will handle
    * Array[Byte], Spread[Double], and MapData.
    */
-  protected def checkResult(result: Any, expected: Any): Boolean = {
+  protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = {
     (result, expected) match {
       case (result: Array[Byte], expected: Array[Byte]) =>
         java.util.Arrays.equals(result, expected)
       case (result: Double, expected: Spread[Double @unchecked]) =>
         expected.asInstanceOf[Spread[Double]].isWithin(result)
+      case (result: ArrayData, expected: ArrayData) =>
+        result.numElements == expected.numElements && {
+          val et = dataType.asInstanceOf[ArrayType].elementType
+          var isSame = true
+          var i = 0
+          while (isSame && i < result.numElements) {
+            isSame = checkResult(result.get(i, et), expected.get(i, et), et)
+            i += 1
+          }
+          isSame
+        }
       case (result: MapData, expected: MapData) =>
-        result.keyArray() == expected.keyArray() && result.valueArray() == expected.valueArray()
+        val kt = dataType.asInstanceOf[MapType].keyType
+        val vt = dataType.asInstanceOf[MapType].valueType
+        checkResult(result.keyArray, expected.keyArray, ArrayType(kt)) &&
+          checkResult(result.valueArray, expected.valueArray, ArrayType(vt))
       case (result: Double, expected: Double) =>
         if (expected.isNaN) result.isNaN else expected == result
       case (result: Float, expected: Float) =>
@@ -108,7 +122,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     val actual = try evaluate(expression, inputRow) catch {
       case e: Exception => fail(s"Exception evaluating $expression", e)
     }
-    if (!checkResult(actual, expected)) {
+    if (!checkResult(actual, expected, expression.dataType)) {
       val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
       fail(s"Incorrect evaluation (codegen off): $expression, " +
         s"actual: $actual, " +
@@ -127,7 +141,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     plan.initialize(0)
 
     val actual = plan(inputRow).get(0, expression.dataType)
-    if (!checkResult(actual, expected)) {
+    if (!checkResult(actual, expected, expression.dataType)) {
       val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
       fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input")
     }
@@ -188,7 +202,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
       expression)
     plan.initialize(0)
     var actual = plan(inputRow).get(0, expression.dataType)
-    assert(checkResult(actual, expected))
+    assert(checkResult(actual, expected, expression.dataType))
 
     plan = generateProject(
       GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
@@ -196,7 +210,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     plan.initialize(0)
     actual = FromUnsafeProjection(expression.dataType :: Nil)(
       plan(inputRow)).get(0, expression.dataType)
-    assert(checkResult(actual, expected))
+    assert(checkResult(actual, expected, expression.dataType))
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/93f35569/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index ff07940..354c878 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -246,6 +246,12 @@ public abstract class ColumnVector implements AutoCloseable {
     public Object get(int ordinal, DataType dataType) {
       throw new UnsupportedOperationException();
     }
+
+    @Override  // required by the supertype; unsupported here, so it always throws
+    public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); }
+
+    @Override  // required by the supertype; unsupported here, so it always throws
+    public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/93f35569/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
index 9a8d449..9eaf44c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
@@ -411,8 +411,8 @@ class ObjectHashAggregateSuite
     actual.zip(expected).foreach { case (lhs: Row, rhs: Row) =>
       assert(lhs.length == rhs.length)
       lhs.toSeq.zip(rhs.toSeq).foreach {
-        case (a: Double, b: Double) => checkResult(a, b +- tolerance)
-        case (a, b) => checkResult(a, b)
+        case (a: Double, b: Double) => checkResult(a, b +- tolerance, DoubleType)
+        case (a, b) => a == b
       }
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org