You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/10/05 22:01:07 UTC

spark git commit: [SPARK-10585] [SQL] only copy data once when generate unsafe projection

Repository: spark
Updated Branches:
  refs/heads/master 883bd8fcc -> c4871369d


[SPARK-10585] [SQL] only copy data once when generate unsafe projection

This PR is a complete rewrite of GenerateUnsafeProjection, with the goal of copying data only once. The old GenerateUnsafeProjection code is left in place to make the review easier.

Instead of generating unsafe conversion code for structs, arrays and maps, we generate code that writes their content directly into the global row buffer.

Author: Wenchen Fan <cl...@163.com>
Author: Wenchen Fan <cl...@outlook.com>

Closes #8747 from cloud-fan/copy-once.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c4871369
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c4871369
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c4871369

Branch: refs/heads/master
Commit: c4871369db96fc33c465d11b3bbd1ffeb3b94e89
Parents: 883bd8f
Author: Wenchen Fan <cl...@163.com>
Authored: Mon Oct 5 13:00:58 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Mon Oct 5 13:00:58 2015 -0700

----------------------------------------------------------------------
 .../catalyst/expressions/UnsafeArrayData.java   |   7 +-
 .../sql/catalyst/expressions/UnsafeMapData.java |  15 +-
 .../sql/catalyst/expressions/UnsafeReaders.java |   6 +
 .../sql/catalyst/expressions/UnsafeRow.java     |   4 +-
 .../catalyst/expressions/UnsafeRowWriters.java  |   4 +-
 .../sql/catalyst/expressions/UnsafeWriters.java |   4 +-
 .../expressions/codegen/BufferHolder.java       |  54 ++++
 .../expressions/codegen/UnsafeArrayWriter.java  | 151 +++++++++
 .../expressions/codegen/UnsafeRowWriter.java    | 199 ++++++++++++
 .../expressions/codegen/CodeGenerator.scala     |   1 +
 .../codegen/GenerateUnsafeProjection.scala      | 284 ++++++++++++++++-
 .../expressions/UnsafeRowConverterSuite.scala   | 305 ++++++++++++++-----
 12 files changed, 950 insertions(+), 84 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c4871369/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 501dff0..da9538b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions;
 import java.math.BigDecimal;
 import java.math.BigInteger;
 
-import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.types.*;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
@@ -256,7 +255,7 @@ public class UnsafeArrayData extends ArrayData {
   }
 
   @Override
-  public InternalRow getStruct(int ordinal, int numFields) {
+  public UnsafeRow getStruct(int ordinal, int numFields) {
     assertIndexIsValid(ordinal);
     final int offset = getElementOffset(ordinal);
     if (offset < 0) return null;
@@ -267,7 +266,7 @@ public class UnsafeArrayData extends ArrayData {
   }
 
   @Override
-  public ArrayData getArray(int ordinal) {
+  public UnsafeArrayData getArray(int ordinal) {
     assertIndexIsValid(ordinal);
     final int offset = getElementOffset(ordinal);
     if (offset < 0) return null;
@@ -276,7 +275,7 @@ public class UnsafeArrayData extends ArrayData {
   }
 
   @Override
-  public MapData getMap(int ordinal) {
+  public UnsafeMapData getMap(int ordinal) {
     assertIndexIsValid(ordinal);
     final int offset = getElementOffset(ordinal);
     if (offset < 0) return null;

http://git-wip-us.apache.org/repos/asf/spark/blob/c4871369/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
index 4621605..e9dab9e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
@@ -17,18 +17,23 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
-import org.apache.spark.sql.types.ArrayData;
 import org.apache.spark.sql.types.MapData;
 
 /**
  * An Unsafe implementation of Map which is backed by raw memory instead of Java objects.
  *
  * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData.
+ *
+ * Note that when we write out this map, we should write `numElements` in the first 4 bytes and
+ * the numBytes of the key array in the next 4 bytes, followed by the key array content and the
+ * value array content, neither of which has its own `numElements` header.
+ * When we read in a map, we should read the first 4 bytes as `numElements` and the next 4 bytes as
+ * numBytes of the key array, and construct the unsafe key and value arrays from these two values.
  */
 public class UnsafeMapData extends MapData {
 
-  public final UnsafeArrayData keys;
-  public final UnsafeArrayData values;
+  private final UnsafeArrayData keys;
+  private final UnsafeArrayData values;
   // The number of elements in this array
   private int numElements;
   // The size of this array's backing data, in bytes
@@ -50,12 +55,12 @@ public class UnsafeMapData extends MapData {
   }
 
   @Override
-  public ArrayData keyArray() {
+  public UnsafeArrayData keyArray() {
     return keys;
   }
 
   @Override
-  public ArrayData valueArray() {
+  public UnsafeArrayData valueArray() {
     return values;
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c4871369/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
index 7b03185..6c5fcbc 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
@@ -21,6 +21,9 @@ import org.apache.spark.unsafe.Platform;
 
 public class UnsafeReaders {
 
+  /**
+   * Reads in unsafe array according to the format described in `UnsafeArrayData`.
+   */
   public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) {
     // Read the number of elements from first 4 bytes.
     final int numElements = Platform.getInt(baseObject, baseOffset);
@@ -30,6 +33,9 @@ public class UnsafeReaders {
     return array;
   }
 
+  /**
+   * Reads in unsafe map according to the format described in `UnsafeMapData`.
+   */
   public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) {
     // Read the number of elements from first 4 bytes.
     final int numElements = Platform.getInt(baseObject, baseOffset);

http://git-wip-us.apache.org/repos/asf/spark/blob/c4871369/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 6c02004..e8ac299 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -446,7 +446,7 @@ public final class UnsafeRow extends MutableRow {
   }
 
   @Override
-  public ArrayData getArray(int ordinal) {
+  public UnsafeArrayData getArray(int ordinal) {
     if (isNullAt(ordinal)) {
       return null;
     } else {
@@ -458,7 +458,7 @@ public final class UnsafeRow extends MutableRow {
   }
 
   @Override
-  public MapData getMap(int ordinal) {
+  public UnsafeMapData getMap(int ordinal) {
     if (isNullAt(ordinal)) {
       return null;
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/c4871369/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
index 2f43db6..0f1e020 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
@@ -233,8 +233,8 @@ public class UnsafeRowWriters {
 
     public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeMapData input) {
       final long offset = target.getBaseOffset() + cursor;
-      final UnsafeArrayData keyArray = input.keys;
-      final UnsafeArrayData valueArray = input.values;
+      final UnsafeArrayData keyArray = input.keyArray();
+      final UnsafeArrayData valueArray = input.valueArray();
       final int keysNumBytes = keyArray.getSizeInBytes();
       final int valuesNumBytes = valueArray.getSizeInBytes();
       final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes;

http://git-wip-us.apache.org/repos/asf/spark/blob/c4871369/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
index cd83695..ce2d9c4 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
@@ -168,8 +168,8 @@ public class UnsafeWriters {
     }
 
     public static int write(Object targetObject, long targetOffset, UnsafeMapData input) {
-      final UnsafeArrayData keyArray = input.keys;
-      final UnsafeArrayData valueArray = input.values;
+      final UnsafeArrayData keyArray = input.keyArray();
+      final UnsafeArrayData valueArray = input.valueArray();
       final int keysNumBytes = keyArray.getSizeInBytes();
       final int valuesNumBytes = valueArray.getSizeInBytes();
       final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes;

http://git-wip-us.apache.org/repos/asf/spark/blob/c4871369/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
new file mode 100644
index 0000000..9c94686
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.unsafe.Platform;
+
+/**
+ * A helper class to manage the row buffer used in `GenerateUnsafeProjection`: `buffer` holds the
+ * bytes written so far and `cursor` is the next write address (Platform.BYTE_ARRAY_OFFSET-based).
+ * Note that it is only used in `GenerateUnsafeProjection`, so it's safe to mark member variables
+ * public for ease of use.
+ */
+public class BufferHolder {
+  public byte[] buffer = new byte[64];
+  public int cursor = Platform.BYTE_ARRAY_OFFSET;
+  /** Ensures room for `neededSize` more bytes after those written; does not move the cursor. */
+  public void grow(int neededSize) {
+    final int length = totalSize() + neededSize;
+    if (buffer.length < length) {
+      // Rare, as the buffer is re-used. NOTE(review): length * 2 overflows for length >= 2^30.
+      final byte[] tmp = new byte[length * 2];
+      Platform.copyMemory(
+        buffer,
+        Platform.BYTE_ARRAY_OFFSET,
+        tmp,
+        Platform.BYTE_ARRAY_OFFSET,
+        totalSize());
+      buffer = tmp;
+    }
+  }
+  /** Rewinds the cursor to the start of the buffer; previously written bytes are not cleared. */
+  public void reset() {
+    cursor = Platform.BYTE_ARRAY_OFFSET;
+  }
+  /** Returns the number of bytes written so far (cursor minus the array base offset). */
+  public int totalSize() {
+    return cursor - Platform.BYTE_ARRAY_OFFSET;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/c4871369/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
new file mode 100644
index 0000000..138178c
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A helper class to write data into the global row buffer using `UnsafeArrayData` format,
+ * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
+ */
+public class UnsafeArrayWriter {
+  // The global buffer this array is written into; assigned by initialize().
+  private BufferHolder holder;
+  // The offset of the global buffer where we start to write this array.
+  private int startingOffset;
+  // Starts an array at holder.cursor. No `numElements` header is written by this class.
+  public void initialize(BufferHolder holder, int numElements, int fixedElementSize) {
+    // We need 4 bytes each element to store offset; filled later by setOffset/setNullAt.
+    final int fixedSize = 4 * numElements;
+
+    this.holder = holder;
+    this.startingOffset = holder.cursor;
+
+    holder.grow(fixedSize);
+    holder.cursor += fixedSize;
+
+    // Grows the global buffer ahead for fixed size data (the cursor is not moved).
+    holder.grow(fixedElementSize * numElements);
+  }
+  // Absolute address of the 4-byte offset slot of `ordinal`.
+  private long getElementOffset(int ordinal) {
+    return startingOffset + 4 * ordinal;
+  }
+  // Marks `ordinal` as null by storing the negated relative cursor in its offset slot.
+  public void setNullAt(int ordinal) {
+    final int relativeOffset = holder.cursor - startingOffset;
+    // Writes negative offset value to represent null element.
+    Platform.putInt(holder.buffer, getElementOffset(ordinal), -relativeOffset);
+  }
+  // Points the slot of `ordinal` at the cursor; callers here invoke it before advancing it.
+  public void setOffset(int ordinal) {
+    final int relativeOffset = holder.cursor - startingOffset;
+    Platform.putInt(holder.buffer, getElementOffset(ordinal), relativeOffset);
+  }
+  // Writes the 8-byte unscaled long; no grow() here, presumably relying on initialize().
+  public void writeCompactDecimal(int ordinal, Decimal input, int precision, int scale) {
+    // make sure Decimal object has the same scale as DecimalType, else write null
+    if (input.changePrecision(precision, scale)) {
+      Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong());
+      setOffset(ordinal);
+      holder.cursor += 8;
+    } else {
+      setNullAt(ordinal);
+    }
+  }
+  // Writes the unscaled value's BigInteger.toByteArray() bytes (at most 16), unpadded.
+  public void write(int ordinal, Decimal input, int precision, int scale) {
+    // make sure Decimal object has the same scale as DecimalType
+    if (input.changePrecision(precision, scale)) {
+      final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+      assert bytes.length <= 16;
+      holder.grow(bytes.length);
+
+      // Write the bytes to the variable length portion.
+      Platform.copyMemory(
+        bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+      setOffset(ordinal);
+      holder.cursor += bytes.length;
+    } else {
+      setNullAt(ordinal);
+    }
+  }
+  // Writes the string's bytes unpadded and records their offset.
+  public void write(int ordinal, UTF8String input) {
+    final int numBytes = input.numBytes();
+
+    // grow the global buffer before writing data.
+    holder.grow(numBytes);
+
+    // Write the bytes to the variable length portion.
+    input.writeToMemory(holder.buffer, holder.cursor);
+
+    setOffset(ordinal);
+
+    // move the cursor forward.
+    holder.cursor += numBytes;
+  }
+  // Copies the binary value unpadded and records its offset.
+  public void write(int ordinal, byte[] input) {
+    // grow the global buffer before writing data.
+    holder.grow(input.length);
+
+    // Write the bytes to the variable length portion.
+    Platform.copyMemory(
+      input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, input.length);
+
+    setOffset(ordinal);
+
+    // move the cursor forward.
+    holder.cursor += input.length;
+  }
+  // Writes an interval as 16 bytes: months (stored as a long), then microseconds.
+  public void write(int ordinal, CalendarInterval input) {
+    // grow the global buffer before writing data.
+    holder.grow(16);
+
+    // Write the months and microseconds fields of Interval to the variable length portion.
+    Platform.putLong(holder.buffer, holder.cursor, input.months);
+    Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
+
+    setOffset(ordinal);
+
+    // move the cursor forward.
+    holder.cursor += 16;
+  }
+
+  // If this array is already an UnsafeArray, we don't need to go through all elements, we can
+  // directly write it: its getSizeInBytes() bytes are copied verbatim at the cursor, which is
+  // then advanced. This is static and sets no offset slot, so presumably the caller records
+  // the offset itself.
+  public static void directWrite(BufferHolder holder, UnsafeArrayData input) {
+    final int numBytes = input.getSizeInBytes();
+
+    // grow the global buffer before writing data.
+    holder.grow(numBytes);
+
+    // Writes the array content to the variable length portion.
+    input.writeToMemory(holder.buffer, holder.cursor);
+
+    holder.cursor += numBytes;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/c4871369/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
new file mode 100644
index 0000000..8b7debd
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.bitset.BitSetMethods;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A helper class to write data into the global row buffer using `UnsafeRow` format,
+ * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
+ */
+public class UnsafeRowWriter {
+  // The global buffer this row is written into; assigned by initialize().
+  private BufferHolder holder;
+  // The offset of the global buffer where we start to write this row.
+  private int startingOffset;
+  private int nullBitsSize;
+  // Starts a row at holder.cursor: the null bit set, then one 8-byte slot per field.
+  public void initialize(BufferHolder holder, int numFields) {
+    this.holder = holder;
+    this.startingOffset = holder.cursor;
+    this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields);
+
+    // grow the global buffer to make sure it has enough space to write fixed-length data.
+    final int fixedSize = nullBitsSize + 8 * numFields;
+    holder.grow(fixedSize);
+    holder.cursor += fixedSize;
+
+    // zero-out the null bits region
+    for (int i = 0; i < nullBitsSize; i += 8) {
+      Platform.putLong(holder.buffer, startingOffset + i, 0L);
+    }
+  }
+  // Zeroes the last partial word of a numBytes-long value at the cursor (before it is written).
+  private void zeroOutPaddingBytes(int numBytes) {
+    if ((numBytes & 0x07) > 0) {
+      Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
+    }
+  }
+  // Sets the null bit of `ordinal` and zeroes its 8-byte fixed-length slot.
+  public void setNullAt(int ordinal) {
+    BitSetMethods.set(holder.buffer, startingOffset, ordinal);
+    Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L);
+  }
+  // Absolute address of the 8-byte fixed-length slot of `ordinal`.
+  public long getFieldOffset(int ordinal) {
+    return startingOffset + nullBitsSize + 8 * ordinal;
+  }
+  // Stores (cursor - row start) << 32 | size in the fixed-length slot of `ordinal`.
+  public void setOffsetAndSize(int ordinal, long size) {
+    setOffsetAndSize(ordinal, holder.cursor, size);
+  }
+  // As above, but with an explicit cursor, presumably one saved before a nested write.
+  public void setOffsetAndSize(int ordinal, long currentCursor, long size) {
+    final long relativeOffset = currentCursor - startingOffset;
+    final long fieldOffset = getFieldOffset(ordinal);
+    final long offsetAndSize = (relativeOffset << 32) | size;
+
+    Platform.putLong(holder.buffer, fieldOffset, offsetAndSize);
+  }
+
+  // Do word alignment for this row: append zero bytes so numBytes plus padding is a multiple of 8,
+  // growing the buffer if needed. todo: remove this after we make unsafe array data word align.
+  public void alignToWords(int numBytes) {
+    final int remainder = numBytes & 0x07;
+
+    if (remainder > 0) {
+      final int paddingBytes = 8 - remainder;
+      holder.grow(paddingBytes);
+
+      for (int i = 0; i < paddingBytes; i++) {
+        Platform.putByte(holder.buffer, holder.cursor, (byte) 0);
+        holder.cursor++;
+      }
+    }
+  }
+  // Stores the unscaled long directly in the field's slot, or sets null if it cannot fit.
+  public void writeCompactDecimal(int ordinal, Decimal input, int precision, int scale) {
+    // make sure Decimal object has the same scale as DecimalType
+    if (input.changePrecision(precision, scale)) {
+      Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong());
+    } else {
+      setNullAt(ordinal);
+    }
+  }
+  // Always reserves 16 bytes; for null/unfit input the slot keeps its offset with size 0.
+  public void write(int ordinal, Decimal input, int precision, int scale) {
+    // grow the global buffer before writing data.
+    holder.grow(16);
+
+    // zero-out the bytes
+    Platform.putLong(holder.buffer, holder.cursor, 0L);
+    Platform.putLong(holder.buffer, holder.cursor + 8, 0L);
+
+    // Make sure Decimal object has the same scale as DecimalType.
+    // Note that we may pass in null Decimal object to set null for it.
+    if (input == null || !input.changePrecision(precision, scale)) {
+      BitSetMethods.set(holder.buffer, startingOffset, ordinal);
+      // keep the offset for future update
+      setOffsetAndSize(ordinal, 0L);
+    } else {
+      final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+      assert bytes.length <= 16;
+
+      // Write the bytes to the variable length portion.
+      Platform.copyMemory(
+        bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+      setOffsetAndSize(ordinal, bytes.length);
+    }
+
+    // move the cursor forward.
+    holder.cursor += 16;
+  }
+  // Writes the string bytes zero-padded to a word boundary; the slot records the unpadded size.
+  public void write(int ordinal, UTF8String input) {
+    final int numBytes = input.numBytes();
+    final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+
+    // grow the global buffer before writing data.
+    holder.grow(roundedSize);
+
+    zeroOutPaddingBytes(numBytes);
+
+    // Write the bytes to the variable length portion.
+    input.writeToMemory(holder.buffer, holder.cursor);
+
+    setOffsetAndSize(ordinal, numBytes);
+
+    // move the cursor forward.
+    holder.cursor += roundedSize;
+  }
+  // Copies the bytes zero-padded to a word boundary; the slot records the unpadded size.
+  public void write(int ordinal, byte[] input) {
+    final int numBytes = input.length;
+    final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+
+    // grow the global buffer before writing data.
+    holder.grow(roundedSize);
+
+    zeroOutPaddingBytes(numBytes);
+
+    // Write the bytes to the variable length portion.
+    Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
+
+    setOffsetAndSize(ordinal, numBytes);
+
+    // move the cursor forward.
+    holder.cursor += roundedSize;
+  }
+  // Writes an interval as 16 bytes: months (stored as a long), then microseconds.
+  public void write(int ordinal, CalendarInterval input) {
+    // grow the global buffer before writing data.
+    holder.grow(16);
+
+    // Write the months and microseconds fields of Interval to the variable length portion.
+    Platform.putLong(holder.buffer, holder.cursor, input.months);
+    Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
+
+    setOffsetAndSize(ordinal, 16);
+
+    // move the cursor forward.
+    holder.cursor += 16;
+  }
+
+  // If this struct is already an UnsafeRow, we don't need to go through all fields, we can
+  // directly write it: its getSizeInBytes() bytes are copied verbatim at the cursor, which is
+  // then advanced. This is static and sets no field slot, so presumably the caller records
+  // the offset and size itself.
+  public static void directWrite(BufferHolder holder, UnsafeRow input) {
+    // No need to zero-out the bytes as UnsafeRow is word aligned for sure.
+    final int numBytes = input.getSizeInBytes();
+    // grow the global buffer before writing data.
+    holder.grow(numBytes);
+    // Write the bytes to the variable length portion.
+    input.writeToMemory(holder.buffer, holder.cursor);
+    // move the cursor forward.
+    holder.cursor += numBytes;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/c4871369/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index da3103b..9a28781 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -272,6 +272,7 @@ class CodeGenContext {
    * 64kb code size limit in JVM
    *
    * @param row the variable name of row that is used by expressions
+   * @param expressions the code snippets that evaluate the expressions
    */
   def splitExpressions(row: String, expressions: Seq[String]): String = {
     val blocks = new ArrayBuffer[String]()

http://git-wip-us.apache.org/repos/asf/spark/blob/c4871369/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 55562fa..99bf50a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -393,10 +393,292 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     case _ => input
   }
 
+  private val rowWriterClass = classOf[UnsafeRowWriter].getName
+  private val arrayWriterClass = classOf[UnsafeArrayWriter].getName
+
+  // TODO: if the nullability of field is correct, we can use it to save null check.
+  private def writeStructToBuffer(
+      ctx: CodeGenContext,
+      input: String,
+      fieldTypes: Seq[DataType],
+      bufferHolder: String): String = {
+    val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
+      val fieldName = ctx.freshName("fieldName")
+      val code = s"final ${ctx.javaType(dt)} $fieldName = ${ctx.getValue(input, dt, i.toString)};"
+      val isNull = s"$input.isNullAt($i)"
+      GeneratedExpressionCode(code, isNull, fieldName)
+    }
+
+    s"""
+      if ($input instanceof UnsafeRow) {
+        $rowWriterClass.directWrite($bufferHolder, (UnsafeRow) $input);
+      } else {
+        ${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)}
+      }
+    """
+  }
+
+  private def writeExpressionsToBuffer(
+      ctx: CodeGenContext,
+      row: String,
+      inputs: Seq[GeneratedExpressionCode],
+      inputTypes: Seq[DataType],
+      bufferHolder: String): String = {
+    val rowWriter = ctx.freshName("rowWriter")
+    ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();")
+
+    val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
+      case ((input, dt), index) =>
+        val tmpCursor = ctx.freshName("tmpCursor")
+
+        val setNull = dt match {
+          case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
+            // Can't call setNullAt() for DecimalType with precision larger than 18.
+            s"$rowWriter.write($index, null, ${t.precision}, ${t.scale});"
+          case _ => s"$rowWriter.setNullAt($index);"
+        }
+
+        val writeField = dt match {
+          case t: StructType =>
+            s"""
+              // Remember the current cursor so that we can calculate how many bytes are
+              // written later.
+              final int $tmpCursor = $bufferHolder.cursor;
+              ${writeStructToBuffer(ctx, input.primitive, t.map(_.dataType), bufferHolder)}
+              $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+            """
+
+          case a @ ArrayType(et, _) =>
+            s"""
+              // Remember the current cursor so that we can calculate how many bytes are
+              // written later.
+              final int $tmpCursor = $bufferHolder.cursor;
+              ${writeArrayToBuffer(ctx, input.primitive, et, bufferHolder)}
+              $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+              $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
+            """
+
+          case m @ MapType(kt, vt, _) =>
+            s"""
+              // Remember the current cursor so that we can calculate how many bytes are
+              // written later.
+              final int $tmpCursor = $bufferHolder.cursor;
+              ${writeMapToBuffer(ctx, input.primitive, kt, vt, bufferHolder)}
+              $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+              $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
+            """
+
+          case _ if ctx.isPrimitiveType(dt) =>
+            val fieldOffset = ctx.freshName("fieldOffset")
+            s"""
+              final long $fieldOffset = $rowWriter.getFieldOffset($index);
+              Platform.putLong($bufferHolder.buffer, $fieldOffset, 0L);
+              ${writePrimitiveType(ctx, input.primitive, dt, s"$bufferHolder.buffer", fieldOffset)}
+            """
+
+          case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
+            s"$rowWriter.writeCompactDecimal($index, ${input.primitive}, " +
+              s"${t.precision}, ${t.scale});"
+
+          case t: DecimalType =>
+            s"$rowWriter.write($index, ${input.primitive}, ${t.precision}, ${t.scale});"
+
+          case NullType => ""
+
+          case _ => s"$rowWriter.write($index, ${input.primitive});"
+        }
+
+        s"""
+          ${input.code}
+          if (${input.isNull}) {
+            $setNull
+          } else {
+            $writeField
+          }
+        """
+    }
+
+    s"""
+      $rowWriter.initialize($bufferHolder, ${inputs.length});
+      ${ctx.splitExpressions(row, writeFields)}
+    """
+  }
+
+  // TODO: if the nullability of array element is correct, we can use it to save null check.
+  private def writeArrayToBuffer(
+      ctx: CodeGenContext,
+      input: String,
+      elementType: DataType,
+      bufferHolder: String,
+      needHeader: Boolean = true): String = {
+    val arrayWriter = ctx.freshName("arrayWriter")
+    ctx.addMutableState(arrayWriterClass, arrayWriter,
+      s"this.$arrayWriter = new $arrayWriterClass();")
+    val numElements = ctx.freshName("numElements")
+    val index = ctx.freshName("index")
+    val element = ctx.freshName("element")
+
+    val jt = ctx.javaType(elementType)
+
+    val fixedElementSize = elementType match {
+      case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8
+      case _ if ctx.isPrimitiveType(jt) => elementType.defaultSize
+      case _ => 0
+    }
+
+    val writeElement = elementType match {
+      case t: StructType =>
+        s"""
+          $arrayWriter.setOffset($index);
+          ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)}
+        """
+
+      case a @ ArrayType(et, _) =>
+        s"""
+          $arrayWriter.setOffset($index);
+          ${writeArrayToBuffer(ctx, element, et, bufferHolder)}
+        """
+
+      case m @ MapType(kt, vt, _) =>
+        s"""
+          $arrayWriter.setOffset($index);
+          ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
+        """
+
+      case _ if ctx.isPrimitiveType(elementType) =>
+        // Should we do word align?
+        val dataSize = elementType.defaultSize
+
+        s"""
+          $arrayWriter.setOffset($index);
+          ${writePrimitiveType(ctx, element, elementType,
+            s"$bufferHolder.buffer", s"$bufferHolder.cursor")}
+          $bufferHolder.cursor += $dataSize;
+        """
+
+      case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
+        s"$arrayWriter.writeCompactDecimal($index, $element, ${t.precision}, ${t.scale});"
+
+      case t: DecimalType =>
+        s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});"
+
+      case NullType => ""
+
+      case _ => s"$arrayWriter.write($index, $element);"
+    }
+
+    val writeHeader = if (needHeader) {
+      // If header is required, we need to write the number of elements into first 4 bytes.
+      s"""
+        $bufferHolder.grow(4);
+        Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $numElements);
+        $bufferHolder.cursor += 4;
+      """
+    } else ""
+
+    s"""
+      final int $numElements = $input.numElements();
+      $writeHeader
+      if ($input instanceof UnsafeArrayData) {
+        $arrayWriterClass.directWrite($bufferHolder, (UnsafeArrayData) $input);
+      } else {
+        $arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize);
+
+        for (int $index = 0; $index < $numElements; $index++) {
+          if ($input.isNullAt($index)) {
+            $arrayWriter.setNullAt($index);
+          } else {
+            final $jt $element = ${ctx.getValue(input, elementType, index)};
+            $writeElement
+          }
+        }
+      }
+    """
+  }
+
+  // TODO: if the nullability of value element is correct, we can use it to save null check.
+  private def writeMapToBuffer(
+      ctx: CodeGenContext,
+      input: String,
+      keyType: DataType,
+      valueType: DataType,
+      bufferHolder: String): String = {
+    val keys = ctx.freshName("keys")
+    val values = ctx.freshName("values")
+    val tmpCursor = ctx.freshName("tmpCursor")
+
+
+    // Writes out unsafe map according to the format described in `UnsafeMapData`.
+    s"""
+      final ArrayData $keys = $input.keyArray();
+      final ArrayData $values = $input.valueArray();
+
+      $bufferHolder.grow(8);
+
+      // Write the numElements into first 4 bytes.
+      Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $keys.numElements());
+
+      $bufferHolder.cursor += 8;
+      // Remember the current cursor so that we can write numBytes of key array later.
+      final int $tmpCursor = $bufferHolder.cursor;
+
+      ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder, needHeader = false)}
+      // Write the numBytes of key array into second 4 bytes.
+      Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor);
+
+      ${writeArrayToBuffer(ctx, values, valueType, bufferHolder, needHeader = false)}
+    """
+  }
+
+  private def writePrimitiveType(
+      ctx: CodeGenContext,
+      input: String,
+      dt: DataType,
+      buffer: String,
+      offset: String) = {
+    assert(ctx.isPrimitiveType(dt))
+
+    val putMethod = s"put${ctx.primitiveTypeName(dt)}"
+
+    dt match {
+      case FloatType | DoubleType =>
+        val normalized = ctx.freshName("normalized")
+        val boxedType = ctx.boxedType(dt)
+        val handleNaN =
+          s"""
+            final ${ctx.javaType(dt)} $normalized;
+            if ($boxedType.isNaN($input)) {
+              $normalized = $boxedType.NaN;
+            } else {
+              $normalized = $input;
+            }
+          """
+
+        s"""
+          $handleNaN
+          Platform.$putMethod($buffer, $offset, $normalized);
+        """
+      case _ => s"Platform.$putMethod($buffer, $offset, $input);"
+    }
+  }
+
   def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = {
     val exprEvals = expressions.map(e => e.gen(ctx))
     val exprTypes = expressions.map(_.dataType)
-    createCodeForStruct(ctx, "i", exprEvals, exprTypes)
+
+    val result = ctx.freshName("result")
+    ctx.addMutableState("UnsafeRow", result, s"this.$result = new UnsafeRow();")
+    val bufferHolder = ctx.freshName("bufferHolder")
+    val holderClass = classOf[BufferHolder].getName
+    ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();")
+
+    val code =
+      s"""
+        $bufferHolder.reset();
+        ${writeExpressionsToBuffer(ctx, "i", exprEvals, exprTypes, bufferHolder)}
+        $result.pointTo($bufferHolder.buffer, ${expressions.length}, $bufferHolder.totalSize());
+      """
+    GeneratedExpressionCode(code, "false", result)
   }
 
   protected def canonicalize(in: Seq[Expression]): Seq[Expression] =

http://git-wip-us.apache.org/repos/asf/spark/blob/c4871369/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 8c72203..c991cd8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.sql.{Date, Timestamp}
-import java.util.Arrays
 
 import org.scalatest.Matchers
 
@@ -43,7 +42,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     row.setInt(2, 2)
 
     val unsafeRow: UnsafeRow = converter.apply(row)
-    assert(converter.apply(row).getSizeInBytes === 8 + (3 * 8))
+    assert(unsafeRow.getSizeInBytes === 8 + (3 * 8))
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getLong(1) === 1)
     assert(unsafeRow.getInt(2) === 2)
@@ -62,6 +61,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(unsafeRowCopy.getLong(0) === 0)
     assert(unsafeRowCopy.getLong(1) === 1)
     assert(unsafeRowCopy.getInt(2) === 2)
+
+    // Make sure the converter can be reused, i.e. we correctly reset all states.
+    val unsafeRow2: UnsafeRow = converter.apply(row)
+    assert(unsafeRow2.getSizeInBytes === 8 + (3 * 8))
+    assert(unsafeRow2.getLong(0) === 0)
+    assert(unsafeRow2.getLong(1) === 1)
+    assert(unsafeRow2.getInt(2) === 2)
   }
 
   test("basic conversion with primitive, string and binary types") {
@@ -176,7 +182,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       r
     }
 
-    // todo: we reuse the UnsafeRow in projection, so these tests are meaningless.
     val setToNullAfterCreation = converter.apply(rowWithNoNullColumns)
     assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
     assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))
@@ -192,7 +197,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       rowWithNoNullColumns.getDecimal(10, 10, 0))
     assert(setToNullAfterCreation.getDecimal(11, 38, 18) ===
       rowWithNoNullColumns.getDecimal(11, 38, 18))
-    // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
 
     for (i <- fieldTypes.indices) {
       // Cann't call setNullAt() on DecimalType
@@ -202,8 +206,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
         setToNullAfterCreation.setNullAt(i)
       }
     }
-    // There are some garbage left in the var-length area
-    assert(Arrays.equals(createdFromNull.getBytes, setToNullAfterCreation.getBytes()))
 
     setToNullAfterCreation.setNullAt(0)
     setToNullAfterCreation.setBoolean(1, false)
@@ -251,107 +253,274 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes)
   }
 
+  test("basic conversion with struct type") {
+    val fieldTypes: Array[DataType] = Array(
+      new StructType().add("i", IntegerType),
+      new StructType().add("nest", new StructType().add("l", LongType))
+    )
+
+    val converter = UnsafeProjection.create(fieldTypes)
+
+    val row = new GenericMutableRow(fieldTypes.length)
+    row.update(0, InternalRow(1))
+    row.update(1, InternalRow(InternalRow(2L)))
+
+    val unsafeRow: UnsafeRow = converter.apply(row)
+    assert(unsafeRow.numFields == 2)
+
+    val row1 = unsafeRow.getStruct(0, 1)
+    assert(row1.getSizeInBytes == 8 + 1 * 8)
+    assert(row1.numFields == 1)
+    assert(row1.getInt(0) == 1)
+
+    val row2 = unsafeRow.getStruct(1, 1)
+    assert(row2.numFields() == 1)
+
+    val innerRow = row2.getStruct(0, 1)
+
+    {
+      assert(innerRow.getSizeInBytes == 8 + 1 * 8)
+      assert(innerRow.numFields == 1)
+      assert(innerRow.getLong(0) == 2L)
+    }
+
+    assert(row2.getSizeInBytes == 8 + 1 * 8 + innerRow.getSizeInBytes)
+
+    assert(unsafeRow.getSizeInBytes == 8 + 2 * 8 + row1.getSizeInBytes + row2.getSizeInBytes)
+  }
+
+  private def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray)
+
+  private def createMap(keys: Any*)(values: Any*): MapData = {
+    assert(keys.length == values.length)
+    new ArrayBasedMapData(createArray(keys: _*), createArray(values: _*))
+  }
+
+  private def arraySizeInRow(numBytes: Int): Int = roundedSize(4 + numBytes)
+
+  private def mapSizeInRow(numBytes: Int): Int = roundedSize(8 + numBytes)
+
+  private def testArrayInt(array: UnsafeArrayData, values: Seq[Int]): Unit = {
+    assert(array.numElements == values.length)
+    assert(array.getSizeInBytes == (4 + 4) * values.length)
+    values.zipWithIndex.foreach {
+      case (value, index) => assert(array.getInt(index) == value)
+    }
+  }
+
+  private def testMapInt(map: UnsafeMapData, keys: Seq[Int], values: Seq[Int]): Unit = {
+    assert(keys.length == values.length)
+    assert(map.numElements == keys.length)
+
+    testArrayInt(map.keyArray, keys)
+    testArrayInt(map.valueArray, values)
+
+    assert(map.getSizeInBytes == map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
+  }
+
   test("basic conversion with array type") {
     val fieldTypes: Array[DataType] = Array(
-      ArrayType(LongType),
-      ArrayType(ArrayType(LongType))
+      ArrayType(IntegerType),
+      ArrayType(ArrayType(IntegerType))
     )
     val converter = UnsafeProjection.create(fieldTypes)
 
-    val array1 = new GenericArrayData(Array[Any](1L, 2L))
-    val array2 = new GenericArrayData(Array[Any](new GenericArrayData(Array[Any](3L, 4L))))
     val row = new GenericMutableRow(fieldTypes.length)
-    row.update(0, array1)
-    row.update(1, array2)
+    row.update(0, createArray(1, 2))
+    row.update(1, createArray(createArray(3, 4)))
 
     val unsafeRow: UnsafeRow = converter.apply(row)
     assert(unsafeRow.numFields() == 2)
 
-    val unsafeArray1 = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData]
-    assert(unsafeArray1.getSizeInBytes == 4 * 2 + 8 * 2)
-    assert(unsafeArray1.numElements() == 2)
-    assert(unsafeArray1.getLong(0) == 1L)
-    assert(unsafeArray1.getLong(1) == 2L)
+    val unsafeArray1 = unsafeRow.getArray(0)
+    testArrayInt(unsafeArray1, Seq(1, 2))
 
-    val unsafeArray2 = unsafeRow.getArray(1).asInstanceOf[UnsafeArrayData]
-    assert(unsafeArray2.numElements() == 1)
+    val unsafeArray2 = unsafeRow.getArray(1)
+    assert(unsafeArray2.numElements == 1)
 
-    val nestedArray = unsafeArray2.getArray(0).asInstanceOf[UnsafeArrayData]
-    assert(nestedArray.getSizeInBytes == 4 * 2 + 8 * 2)
-    assert(nestedArray.numElements() == 2)
-    assert(nestedArray.getLong(0) == 3L)
-    assert(nestedArray.getLong(1) == 4L)
+    val nestedArray = unsafeArray2.getArray(0)
+    testArrayInt(nestedArray, Seq(3, 4))
 
-    assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes)
+    assert(unsafeArray2.getSizeInBytes == 4 + (4 + nestedArray.getSizeInBytes))
 
-    val array1Size = roundedSize(4 + unsafeArray1.getSizeInBytes)
-    val array2Size = roundedSize(4 + unsafeArray2.getSizeInBytes)
+    val array1Size = arraySizeInRow(unsafeArray1.getSizeInBytes)
+    val array2Size = arraySizeInRow(unsafeArray2.getSizeInBytes)
     assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size)
   }
 
   test("basic conversion with map type") {
-    def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray)
+    val fieldTypes: Array[DataType] = Array(
+      MapType(IntegerType, IntegerType),
+      MapType(IntegerType, MapType(IntegerType, IntegerType))
+    )
+    val converter = UnsafeProjection.create(fieldTypes)
 
-    def testIntLongMap(map: UnsafeMapData, keys: Array[Int], values: Array[Long]): Unit = {
-      val numElements = keys.length
-      assert(map.numElements() == numElements)
+    val map1 = createMap(1, 2)(3, 4)
 
-      val keyArray = map.keys
-      assert(keyArray.getSizeInBytes == 4 * numElements + 4 * numElements)
-      assert(keyArray.numElements() == numElements)
-      keys.zipWithIndex.foreach { case (key, i) =>
-        assert(keyArray.getInt(i) == key)
-      }
+    val innerMap = createMap(5, 6)(7, 8)
+    val map2 = createMap(9)(innerMap)
 
-      val valueArray = map.values
-      assert(valueArray.getSizeInBytes == 4 * numElements + 8 * numElements)
-      assert(valueArray.numElements() == numElements)
-      values.zipWithIndex.foreach { case (value, i) =>
-        assert(valueArray.getLong(i) == value)
-      }
+    val row = new GenericMutableRow(fieldTypes.length)
+    row.update(0, map1)
+    row.update(1, map2)
 
-      assert(map.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+    val unsafeRow: UnsafeRow = converter.apply(row)
+    assert(unsafeRow.numFields == 2)
+
+    val unsafeMap1 = unsafeRow.getMap(0)
+    testMapInt(unsafeMap1, Seq(1, 2), Seq(3, 4))
+
+    val unsafeMap2 = unsafeRow.getMap(1)
+    assert(unsafeMap2.numElements == 1)
+
+    val keyArray = unsafeMap2.keyArray
+    testArrayInt(keyArray, Seq(9))
+
+    val valueArray = unsafeMap2.valueArray
+
+    {
+      assert(valueArray.numElements == 1)
+
+      val nestedMap = valueArray.getMap(0)
+      testMapInt(nestedMap, Seq(5, 6), Seq(7, 8))
+
+      assert(valueArray.getSizeInBytes == 4 + (8 + nestedMap.getSizeInBytes))
     }
 
+    assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+
+    val map1Size = mapSizeInRow(unsafeMap1.getSizeInBytes)
+    val map2Size = mapSizeInRow(unsafeMap2.getSizeInBytes)
+    assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
+  }
+
+  test("basic conversion with struct and array") {
     val fieldTypes: Array[DataType] = Array(
-      MapType(IntegerType, LongType),
-      MapType(IntegerType, MapType(IntegerType, LongType))
+      new StructType().add("arr", ArrayType(IntegerType)),
+      ArrayType(new StructType().add("l", LongType))
     )
     val converter = UnsafeProjection.create(fieldTypes)
 
-    val map1 = new ArrayBasedMapData(createArray(1, 2), createArray(3L, 4L))
+    val row = new GenericMutableRow(fieldTypes.length)
+    row.update(0, InternalRow(createArray(1)))
+    row.update(1, createArray(InternalRow(2L)))
+
+    val unsafeRow: UnsafeRow = converter.apply(row)
+    assert(unsafeRow.numFields() == 2)
+
+    val field1 = unsafeRow.getStruct(0, 1)
+    assert(field1.numFields == 1)
+
+    val innerArray = field1.getArray(0)
+    testArrayInt(innerArray, Seq(1))
 
-    val innerMap = new ArrayBasedMapData(createArray(5, 6), createArray(7L, 8L))
-    val map2 = new ArrayBasedMapData(createArray(9), createArray(innerMap))
+    assert(field1.getSizeInBytes == 8 + 8 + arraySizeInRow(innerArray.getSizeInBytes))
+
+    val field2 = unsafeRow.getArray(1)
+    assert(field2.numElements == 1)
+
+    val innerStruct = field2.getStruct(0, 1)
+
+    {
+      assert(innerStruct.numFields == 1)
+      assert(innerStruct.getSizeInBytes == 8 + 8)
+      assert(innerStruct.getLong(0) == 2L)
+    }
+
+    assert(field2.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+
+    assert(unsafeRow.getSizeInBytes ==
+      8 + 8 * 2 + field1.getSizeInBytes + arraySizeInRow(field2.getSizeInBytes))
+  }
+
+  test("basic conversion with struct and map") {
+    val fieldTypes: Array[DataType] = Array(
+      new StructType().add("map", MapType(IntegerType, IntegerType)),
+      MapType(IntegerType, new StructType().add("l", LongType))
+    )
+    val converter = UnsafeProjection.create(fieldTypes)
 
     val row = new GenericMutableRow(fieldTypes.length)
-    row.update(0, map1)
-    row.update(1, map2)
+    row.update(0, InternalRow(createMap(1)(2)))
+    row.update(1, createMap(3)(InternalRow(4L)))
 
     val unsafeRow: UnsafeRow = converter.apply(row)
     assert(unsafeRow.numFields() == 2)
 
-    val unsafeMap1 = unsafeRow.getMap(0).asInstanceOf[UnsafeMapData]
-    testIntLongMap(unsafeMap1, Array(1, 2), Array(3L, 4L))
+    val field1 = unsafeRow.getStruct(0, 1)
+    assert(field1.numFields == 1)
 
-    val unsafeMap2 = unsafeRow.getMap(1).asInstanceOf[UnsafeMapData]
-    assert(unsafeMap2.numElements() == 1)
+    val innerMap = field1.getMap(0)
+    testMapInt(innerMap, Seq(1), Seq(2))
 
-    val keyArray = unsafeMap2.keys
-    assert(keyArray.getSizeInBytes == 4 + 4)
-    assert(keyArray.numElements() == 1)
-    assert(keyArray.getInt(0) == 9)
+    assert(field1.getSizeInBytes == 8 + 8 + mapSizeInRow(innerMap.getSizeInBytes))
 
-    val valueArray = unsafeMap2.values
-    assert(valueArray.numElements() == 1)
-    val nestedMap = valueArray.getMap(0).asInstanceOf[UnsafeMapData]
-    testIntLongMap(nestedMap, Array(5, 6), Array(7L, 8L))
-    assert(valueArray.getSizeInBytes == 4 + 8 + nestedMap.getSizeInBytes)
+    val field2 = unsafeRow.getMap(1)
 
-    assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+    val keyArray = field2.keyArray
+    testArrayInt(keyArray, Seq(3))
 
-    val map1Size = roundedSize(8 + unsafeMap1.getSizeInBytes)
-    val map2Size = roundedSize(8 + unsafeMap2.getSizeInBytes)
-    assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
+    val valueArray = field2.valueArray
+
+    {
+      assert(valueArray.numElements == 1)
+
+      val innerStruct = valueArray.getStruct(0, 1)
+      assert(innerStruct.numFields == 1)
+      assert(innerStruct.getSizeInBytes == 8 + 8)
+      assert(innerStruct.getLong(0) == 4L)
+
+      assert(valueArray.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+    }
+
+    assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+
+    assert(unsafeRow.getSizeInBytes ==
+      8 + 8 * 2 + field1.getSizeInBytes + mapSizeInRow(field2.getSizeInBytes))
+  }
+
+  test("basic conversion with array and map") {
+    val fieldTypes: Array[DataType] = Array(
+      ArrayType(MapType(IntegerType, IntegerType)),
+      MapType(IntegerType, ArrayType(IntegerType))
+    )
+    val converter = UnsafeProjection.create(fieldTypes)
+
+    val row = new GenericMutableRow(fieldTypes.length)
+    row.update(0, createArray(createMap(1)(2)))
+    row.update(1, createMap(3)(createArray(4)))
+
+    val unsafeRow: UnsafeRow = converter.apply(row)
+    assert(unsafeRow.numFields() == 2)
+
+    val field1 = unsafeRow.getArray(0)
+    assert(field1.numElements == 1)
+
+    val innerMap = field1.getMap(0)
+    testMapInt(innerMap, Seq(1), Seq(2))
+
+    assert(field1.getSizeInBytes == 4 + (8 + innerMap.getSizeInBytes))
+
+    val field2 = unsafeRow.getMap(1)
+    assert(field2.numElements == 1)
+
+    val keyArray = field2.keyArray
+    testArrayInt(keyArray, Seq(3))
+
+    val valueArray = field2.valueArray
+
+    {
+      assert(valueArray.numElements == 1)
+
+      val innerArray = valueArray.getArray(0)
+      testArrayInt(innerArray, Seq(4))
+
+      assert(valueArray.getSizeInBytes == 4 + (4 + innerArray.getSizeInBytes))
+    }
+
+    assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+
+    assert(unsafeRow.getSizeInBytes ==
+      8 + 8 * 2 + arraySizeInRow(field1.getSizeInBytes) + mapSizeInRow(field2.getSizeInBytes))
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org