Posted to commits@spark.apache.org by hv...@apache.org on 2018/03/16 17:28:22 UTC

spark git commit: [SPARK-23581][SQL] Add interpreted unsafe projection

Repository: spark
Updated Branches:
  refs/heads/master dffeac369 -> 88d8de926


[SPARK-23581][SQL] Add interpreted unsafe projection

## What changes were proposed in this pull request?
We can currently only create unsafe rows using code generation. This is a problem when code generation fails: there is no fallback, and as a result we cannot execute the query.

This PR adds an interpreted version of `UnsafeProjection`. The implementation is modeled after `InterpretedMutableProjection`. It stores the expression results in a `GenericInternalRow`, and it then uses a conversion function to convert the `GenericInternalRow` into an `UnsafeRow`.
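
For illustration, a minimal usage sketch of the new projection (the `BoundReference` inputs and the literal row are illustrative, not code from this patch):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, InterpretedUnsafeProjection}
    import org.apache.spark.sql.types.{IntegerType, StringType}
    import org.apache.spark.unsafe.types.UTF8String

    // Two bound expressions standing in for a projected schema.
    val exprs = Seq(
      BoundReference(0, IntegerType, nullable = true),
      BoundReference(1, StringType, nullable = true))

    // Interpreted path: evaluation and UnsafeRow conversion without Janino.
    val projection = InterpretedUnsafeProjection.create(exprs)
    val row = projection(InternalRow(1, UTF8String.fromString("a")))
    assert(row.getInt(0) == 1 && row.getUTF8String(1).toString == "a")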

This PR does not implement the actual codegen-to-interpreted fallback logic; that will be done in a follow-up.

## How was this patch tested?
I am piggybacking on existing `UnsafeProjection` tests, and I have added an interpreted version for each of these.

Author: Herman van Hovell <hv...@databricks.com>

Closes #20750 from hvanhovell/SPARK-23581.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/88d8de92
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/88d8de92
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/88d8de92

Branch: refs/heads/master
Commit: 88d8de9260edf6e9d5449ff7ef6e35d16051fc9f
Parents: dffeac3
Author: Herman van Hovell <hv...@databricks.com>
Authored: Fri Mar 16 18:28:16 2018 +0100
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Fri Mar 16 18:28:16 2018 +0100

----------------------------------------------------------------------
 .../expressions/codegen/UnsafeArrayWriter.java  |  32 +-
 .../expressions/codegen/UnsafeRowWriter.java    |  30 +-
 .../expressions/codegen/UnsafeWriter.java       |  43 +++
 .../sql/catalyst/expressions/Expression.scala   |  26 ++
 .../InterpretedUnsafeProjection.scala           | 366 +++++++++++++++++++
 .../expressions/MonotonicallyIncreasingID.scala |   4 +-
 .../sql/catalyst/expressions/Projection.scala   |  19 +-
 .../codegen/GenerateUnsafeProjection.scala      |   2 +-
 .../expressions/randomExpressions.scala         |   6 +-
 .../catalyst/expressions/ComplexTypeSuite.scala |   2 +-
 .../expressions/ExpressionEvalHelper.scala      |  20 +-
 .../expressions/ObjectExpressionsSuite.scala    |  21 +-
 .../catalyst/expressions/ScalaUDFSuite.scala    |   2 +-
 .../expressions/UnsafeRowConverterSuite.scala   |  56 +--
 14 files changed, 555 insertions(+), 74 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
index 791e8d8..82cd1b2 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -30,7 +30,7 @@ import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculat
  * A helper class to write data into global row buffer using `UnsafeArrayData` format,
  * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
  */
-public class UnsafeArrayWriter {
+public final class UnsafeArrayWriter extends UnsafeWriter {
 
   private BufferHolder holder;
 
@@ -83,7 +83,7 @@ public class UnsafeArrayWriter {
     return startingOffset + headerInBytes + ordinal * elementSize;
   }
 
-  public void setOffsetAndSize(int ordinal, long currentCursor, int size) {
+  public void setOffsetAndSize(int ordinal, int currentCursor, int size) {
     assertIndexIsValid(ordinal);
     final long relativeOffset = currentCursor - startingOffset;
     final long offsetAndSize = (relativeOffset << 32) | (long)size;
@@ -96,49 +96,31 @@ public class UnsafeArrayWriter {
     BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal);
   }
 
-  public void setNullBoolean(int ordinal) {
-    setNullBit(ordinal);
-    // put zero into the corresponding field when set null
-    Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), false);
-  }
-
-  public void setNullByte(int ordinal) {
+  public void setNull1Bytes(int ordinal) {
     setNullBit(ordinal);
     // put zero into the corresponding field when set null
     Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0);
   }
 
-  public void setNullShort(int ordinal) {
+  public void setNull2Bytes(int ordinal) {
     setNullBit(ordinal);
     // put zero into the corresponding field when set null
     Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0);
   }
 
-  public void setNullInt(int ordinal) {
+  public void setNull4Bytes(int ordinal) {
     setNullBit(ordinal);
     // put zero into the corresponding field when set null
     Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0);
   }
 
-  public void setNullLong(int ordinal) {
+  public void setNull8Bytes(int ordinal) {
     setNullBit(ordinal);
     // put zero into the corresponding field when set null
     Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0);
   }
 
-  public void setNullFloat(int ordinal) {
-    setNullBit(ordinal);
-    // put zero into the corresponding field when set null
-    Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), (float)0);
-  }
-
-  public void setNullDouble(int ordinal) {
-    setNullBit(ordinal);
-    // put zero into the corresponding field when set null
-    Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), (double)0);
-  }
-
-  public void setNull(int ordinal) { setNullLong(ordinal); }
+  public void setNull(int ordinal) { setNull8Bytes(ordinal); }
 
   public void write(int ordinal, boolean value) {
     assertIndexIsValid(ordinal);

http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
index 5d9515c..2620bbc 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -38,7 +38,7 @@ import org.apache.spark.unsafe.types.UTF8String;
  * beginning of the global row buffer, we don't need to update `startingOffset` and can just call
  * `zeroOutNullBytes` before writing new data.
  */
-public class UnsafeRowWriter {
+public final class UnsafeRowWriter extends UnsafeWriter {
 
   private final BufferHolder holder;
   // The offset of the global buffer where we start to write this row.
@@ -93,18 +93,38 @@ public class UnsafeRowWriter {
     Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L);
   }
 
+  @Override
+  public void setNull1Bytes(int ordinal) {
+    setNullAt(ordinal);
+  }
+
+  @Override
+  public void setNull2Bytes(int ordinal) {
+    setNullAt(ordinal);
+  }
+
+  @Override
+  public void setNull4Bytes(int ordinal) {
+    setNullAt(ordinal);
+  }
+
+  @Override
+  public void setNull8Bytes(int ordinal) {
+    setNullAt(ordinal);
+  }
+
   public long getFieldOffset(int ordinal) {
     return startingOffset + nullBitsSize + 8 * ordinal;
   }
 
-  public void setOffsetAndSize(int ordinal, long size) {
+  public void setOffsetAndSize(int ordinal, int size) {
     setOffsetAndSize(ordinal, holder.cursor, size);
   }
 
-  public void setOffsetAndSize(int ordinal, long currentCursor, long size) {
+  public void setOffsetAndSize(int ordinal, int currentCursor, int size) {
     final long relativeOffset = currentCursor - startingOffset;
     final long fieldOffset = getFieldOffset(ordinal);
-    final long offsetAndSize = (relativeOffset << 32) | size;
+    final long offsetAndSize = (relativeOffset << 32) | (long) size;
 
     Platform.putLong(holder.buffer, fieldOffset, offsetAndSize);
   }
@@ -174,7 +194,7 @@ public class UnsafeRowWriter {
       if (input == null || !input.changePrecision(precision, scale)) {
         BitSetMethods.set(holder.buffer, startingOffset, ordinal);
         // keep the offset for future update
-        setOffsetAndSize(ordinal, 0L);
+        setOffsetAndSize(ordinal, 0);
       } else {
         final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
         assert bytes.length <= 16;
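
The row and array writers store each variable-length field as a single long: the field's offset relative to the row start in the upper 32 bits and its size in the lower 32 bits. A small sketch of the encoding, with helper names of my own (equivalent to `(relativeOffset << 32) | (long) size` for the non-negative sizes used here):

    // Pack a relative offset and a size into one 8-byte word.
    def pack(relativeOffset: Long, size: Int): Long =
      (relativeOffset << 32) | (size & 0xFFFFFFFFL)

    def offsetOf(word: Long): Int = (word >>> 32).toInt
    def sizeOf(word: Long): Int = word.toInt

    val word = pack(relativeOffset = 24L, size = 5)
    assert(offsetOf(word) == 24 && sizeOf(word) == 5)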

http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
new file mode 100644
index 0000000..c94b5c7
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * Base class for writing Unsafe* structures.
+ */
+public abstract class UnsafeWriter {
+  public abstract void setNull1Bytes(int ordinal);
+  public abstract void setNull2Bytes(int ordinal);
+  public abstract void setNull4Bytes(int ordinal);
+  public abstract void setNull8Bytes(int ordinal);
+  public abstract void write(int ordinal, boolean value);
+  public abstract void write(int ordinal, byte value);
+  public abstract void write(int ordinal, short value);
+  public abstract void write(int ordinal, int value);
+  public abstract void write(int ordinal, long value);
+  public abstract void write(int ordinal, float value);
+  public abstract void write(int ordinal, double value);
+  public abstract void write(int ordinal, Decimal input, int precision, int scale);
+  public abstract void write(int ordinal, UTF8String input);
+  public abstract void write(int ordinal, byte[] input);
+  public abstract void write(int ordinal, CalendarInterval input);
+  public abstract void setOffsetAndSize(int ordinal, int currentCursor, int size);
+}
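
Pulling these signatures into a shared base class is what lets the interpreted projection build one field-writer closure per data type and reuse it for rows and arrays alike. A sketch of that shape (the closure form mirrors `generateFieldWriter` in the new Scala file below):

    import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
    import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeWriter

    // `writer` may be an UnsafeRowWriter or an UnsafeArrayWriter;
    // the closure does not care which.
    def intFieldWriter(writer: UnsafeWriter): (SpecializedGetters, Int) => Unit =
      (v, i) => writer.write(i, v.getInt(i))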

http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index ed90b18..d7f9e38 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -328,6 +328,32 @@ trait Nondeterministic extends Expression {
   protected def evalInternal(input: InternalRow): Any
 }
 
+/**
+ * An expression that contains mutable state. A stateful expression is always non-deterministic
+ * because the results it produces during evaluation are not only dependent on the given input
+ * but also on its internal state.
+ *
+ * The state of the expressions is generally not exposed in the parameter list and this makes
+ * comparing stateful expressions problematic because similar stateful expressions (with the same
+ * parameter list) but with different internal state will be considered equal. This is especially
+ * problematic during tree transformations. In order to counter this the `fastEquals` method for
+ * stateful expressions only returns `true` for the same reference.
+ *
+ * A stateful expression should never be evaluated multiple times for a single row. This should
+ * only be a problem for interpreted execution. This can be prevented by creating fresh copies
+ * of the stateful expression before execution; these can be made using the `freshCopy` function.
+ */
+trait Stateful extends Nondeterministic {
+  /**
+   * Return a fresh uninitialized copy of the stateful expression.
+   */
+  def freshCopy(): Stateful
+
+  /**
+   * Only the same reference is considered equal.
+   */
+  override def fastEquals(other: TreeNode[_]): Boolean = this eq other
+}
 
 /**
  * A leaf expression, i.e. one without any child expressions.
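
As an illustration of the new trait, a hypothetical stateful expression (not part of this patch) could look like the sketch below:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Stateful}
    import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
    import org.apache.spark.sql.types.{DataType, LongType}

    // Hypothetical: counts the rows seen in a partition. freshCopy() gives each
    // projection its own count, and the reference-based fastEquals keeps tree
    // transformations from deduplicating two counters that merely look alike.
    case class RowCounter() extends LeafExpression with Stateful {
      @transient private[this] var count: Long = 0L
      override protected def initializeInternal(partitionIndex: Int): Unit = count = 0L
      override protected def evalInternal(input: InternalRow): Any = { count += 1; count }
      override def freshCopy(): RowCounter = RowCounter()
      override def nullable: Boolean = false
      override def dataType: DataType = LongType
      override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
        throw new UnsupportedOperationException("interpreted-only sketch")
    }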

http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
new file mode 100644
index 0000000..0da5ece
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter}
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.types.{UserDefinedType, _}
+import org.apache.spark.unsafe.Platform
+
+/**
+ * An interpreted unsafe projection. This class reuses the [[UnsafeRow]] it produces; a consumer
+ * should copy the row if it is being buffered. This class is not thread safe.
+ *
+ * @param expressions that produce the resulting fields. These expressions must be bound
+ *                    to a schema.
+ */
+class InterpretedUnsafeProjection(expressions: Array[Expression]) extends UnsafeProjection {
+  import InterpretedUnsafeProjection._
+
+  /** Number of (top level) fields in the resulting row. */
+  private[this] val numFields = expressions.length
+
+  /** Array that holds the expression results. */
+  private[this] val values = new Array[Any](numFields)
+
+  /** The row representing the expression results. */
+  private[this] val intermediate = new GenericInternalRow(values)
+
+  /** The row returned by the projection. */
+  private[this] val result = new UnsafeRow(numFields)
+
+  /** The buffer which holds the resulting row's backing data. */
+  private[this] val holder = new BufferHolder(result, numFields * 32)
+
+  /** The writer that writes the intermediate result to the result row. */
+  private[this] val writer: InternalRow => Unit = {
+    val rowWriter = new UnsafeRowWriter(holder, numFields)
+    val baseWriter = generateStructWriter(
+      holder,
+      rowWriter,
+      expressions.map(e => StructField("", e.dataType, e.nullable)))
+    if (!expressions.exists(_.nullable)) {
+      // No nullable fields. The top-level null bit mask will always be zeroed out.
+      baseWriter
+    } else {
+      // Zero out the null bit mask before we write the row.
+      row => {
+        rowWriter.zeroOutNullBytes()
+        baseWriter(row)
+      }
+    }
+  }
+
+  override def initialize(partitionIndex: Int): Unit = {
+    expressions.foreach(_.foreach {
+      case n: Nondeterministic => n.initialize(partitionIndex)
+      case _ =>
+    })
+  }
+
+  override def apply(row: InternalRow): UnsafeRow = {
+    // Put the expression results in the intermediate row.
+    var i = 0
+    while (i < numFields) {
+      values(i) = expressions(i).eval(row)
+      i += 1
+    }
+
+    // Write the intermediate row to an unsafe row.
+    holder.reset()
+    writer(intermediate)
+    result.setTotalSize(holder.totalSize())
+    result
+  }
+}
+
+/**
+ * Helper functions for creating an [[InterpretedUnsafeProjection]].
+ */
+object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
+
+  /**
+   * Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
+   */
+  override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
+    // We need to make sure that we do not reuse stateful expressions.
+    val cleanedExpressions = exprs.map(_.transform {
+      case s: Stateful => s.freshCopy()
+    })
+    new InterpretedUnsafeProjection(cleanedExpressions.toArray)
+  }
+
+  /**
+   * Generate a struct writer function. The generated function writes an [[InternalRow]] to the
+   * given buffer using the given [[UnsafeRowWriter]].
+   */
+  private def generateStructWriter(
+      bufferHolder: BufferHolder,
+      rowWriter: UnsafeRowWriter,
+      fields: Array[StructField]): InternalRow => Unit = {
+    val numFields = fields.length
+
+    // Create field writers.
+    val fieldWriters = fields.map { field =>
+      generateFieldWriter(bufferHolder, rowWriter, field.dataType, field.nullable)
+    }
+    // Create basic writer.
+    row => {
+      var i = 0
+      while (i < numFields) {
+        fieldWriters(i).apply(row, i)
+        i += 1
+      }
+    }
+  }
+
+  /**
+   * Generate a writer function for a struct field, array element, map key or map value. The
+   * generated function writes the element at an index in a [[SpecializedGetters]] object (row
+   * or array) to the given buffer using the given [[UnsafeWriter]].
+   */
+  private def generateFieldWriter(
+      bufferHolder: BufferHolder,
+      writer: UnsafeWriter,
+      dt: DataType,
+      nullable: Boolean): (SpecializedGetters, Int) => Unit = {
+
+    // Create the basic writer.
+    val unsafeWriter: (SpecializedGetters, Int) => Unit = dt match {
+      case BooleanType =>
+        (v, i) => writer.write(i, v.getBoolean(i))
+
+      case ByteType =>
+        (v, i) => writer.write(i, v.getByte(i))
+
+      case ShortType =>
+        (v, i) => writer.write(i, v.getShort(i))
+
+      case IntegerType | DateType =>
+        (v, i) => writer.write(i, v.getInt(i))
+
+      case LongType | TimestampType =>
+        (v, i) => writer.write(i, v.getLong(i))
+
+      case FloatType =>
+        (v, i) => writer.write(i, v.getFloat(i))
+
+      case DoubleType =>
+        (v, i) => writer.write(i, v.getDouble(i))
+
+      case DecimalType.Fixed(precision, scale) =>
+        (v, i) => writer.write(i, v.getDecimal(i, precision, scale), precision, scale)
+
+      case CalendarIntervalType =>
+        (v, i) => writer.write(i, v.getInterval(i))
+
+      case BinaryType =>
+        (v, i) => writer.write(i, v.getBinary(i))
+
+      case StringType =>
+        (v, i) => writer.write(i, v.getUTF8String(i))
+
+      case StructType(fields) =>
+        val numFields = fields.length
+        val rowWriter = new UnsafeRowWriter(bufferHolder, numFields)
+        val structWriter = generateStructWriter(bufferHolder, rowWriter, fields)
+        (v, i) => {
+          val tmpCursor = bufferHolder.cursor
+          v.getStruct(i, fields.length) match {
+            case row: UnsafeRow =>
+              writeUnsafeData(
+                bufferHolder,
+                row.getBaseObject,
+                row.getBaseOffset,
+                row.getSizeInBytes)
+            case row =>
+              // Nested struct. We don't know where this will start because a row can be
+              // variable length, so we need to update the offsets and zero out the bit mask.
+              rowWriter.reset()
+              structWriter.apply(row)
+          }
+          writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor)
+        }
+
+      case ArrayType(elementType, containsNull) =>
+        val arrayWriter = new UnsafeArrayWriter
+        val elementSize = getElementSize(elementType)
+        val elementWriter = generateFieldWriter(
+          bufferHolder,
+          arrayWriter,
+          elementType,
+          containsNull)
+        (v, i) => {
+          val tmpCursor = bufferHolder.cursor
+          writeArray(bufferHolder, arrayWriter, elementWriter, v.getArray(i), elementSize)
+          writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor)
+        }
+
+      case MapType(keyType, valueType, valueContainsNull) =>
+        val keyArrayWriter = new UnsafeArrayWriter
+        val keySize = getElementSize(keyType)
+        val keyWriter = generateFieldWriter(
+          bufferHolder,
+          keyArrayWriter,
+          keyType,
+          nullable = false)
+        val valueArrayWriter = new UnsafeArrayWriter
+        val valueSize = getElementSize(valueType)
+        val valueWriter = generateFieldWriter(
+          bufferHolder,
+          valueArrayWriter,
+          valueType,
+          valueContainsNull)
+        (v, i) => {
+          val tmpCursor = bufferHolder.cursor
+          v.getMap(i) match {
+            case map: UnsafeMapData =>
+              writeUnsafeData(
+                bufferHolder,
+                map.getBaseObject,
+                map.getBaseOffset,
+                map.getSizeInBytes)
+            case map =>
+              // preserve 8 bytes to write the key array numBytes later.
+              bufferHolder.grow(8)
+              bufferHolder.cursor += 8
+
+              // Write the keys and write the numBytes of key array into the first 8 bytes.
+              writeArray(bufferHolder, keyArrayWriter, keyWriter, map.keyArray(), keySize)
+              Platform.putLong(bufferHolder.buffer, tmpCursor, bufferHolder.cursor - tmpCursor - 8)
+
+              // Write the values.
+              writeArray(bufferHolder, valueArrayWriter, valueWriter, map.valueArray(), valueSize)
+          }
+          writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor)
+        }
+
+      case udt: UserDefinedType[_] =>
+        generateFieldWriter(bufferHolder, writer, udt.sqlType, nullable)
+
+      case NullType =>
+        (_, _) => {}
+
+      case _ =>
+        throw new SparkException(s"Unsupported data type $dt")
+    }
+
+    // Always wrap the writer with a null safe version.
+    dt match {
+      case _: UserDefinedType[_] =>
+        // The null wrapper depends on the sql type and not on the UDT.
+        unsafeWriter
+      case DecimalType.Fixed(precision, _) if precision > Decimal.MAX_LONG_DIGITS =>
+        // We can't call setNullAt() for DecimalType with precision larger than 18; we call write
+        // directly. We can use the unwrapped writer directly.
+        unsafeWriter
+      case BooleanType | ByteType =>
+        (v, i) => {
+          if (!v.isNullAt(i)) {
+            unsafeWriter(v, i)
+          } else {
+            writer.setNull1Bytes(i)
+          }
+        }
+      case ShortType =>
+        (v, i) => {
+          if (!v.isNullAt(i)) {
+            unsafeWriter(v, i)
+          } else {
+            writer.setNull2Bytes(i)
+          }
+        }
+      case IntegerType | DateType | FloatType =>
+        (v, i) => {
+          if (!v.isNullAt(i)) {
+            unsafeWriter(v, i)
+          } else {
+            writer.setNull4Bytes(i)
+          }
+        }
+      case _ =>
+        (v, i) => {
+          if (!v.isNullAt(i)) {
+            unsafeWriter(v, i)
+          } else {
+            writer.setNull8Bytes(i)
+          }
+        }
+    }
+  }
+
+  /**
+   * Get the number of bytes that elements of a data type will occupy in the fixed part of an
+   * [[UnsafeArrayData]] object. Reference types are stored as an 8 byte combination of an
+   * offset (upper 4 bytes) and a length (lower 4 bytes); these point to the variable length
+   * portion of the array object. Primitives take up to 8 bytes, depending on the size of the
+   * underlying data type.
+   */
+  private def getElementSize(dataType: DataType): Int = dataType match {
+    case NullType | StringType | BinaryType | CalendarIntervalType |
+         _: DecimalType | _: StructType | _: ArrayType | _: MapType => 8
+    case _ => dataType.defaultSize
+  }
+
+  /**
+   * Write an array to the buffer. If the array is already in serialized form (an instance of
+   * [[UnsafeArrayData]]) then we copy the bytes directly, otherwise we do an element-by-element
+   * copy.
+   */
+  private def writeArray(
+      bufferHolder: BufferHolder,
+      arrayWriter: UnsafeArrayWriter,
+      elementWriter: (SpecializedGetters, Int) => Unit,
+      array: ArrayData,
+      elementSize: Int): Unit = array match {
+    case unsafe: UnsafeArrayData =>
+      writeUnsafeData(
+        bufferHolder,
+        unsafe.getBaseObject,
+        unsafe.getBaseOffset,
+        unsafe.getSizeInBytes)
+    case _ =>
+      val numElements = array.numElements()
+      arrayWriter.initialize(bufferHolder, numElements, elementSize)
+      var i = 0
+      while (i < numElements) {
+        elementWriter.apply(array, i)
+        i += 1
+      }
+  }
+
+  /**
+   * Write an opaque block of data to the buffer. This is used to copy
+   * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects.
+   */
+  private def writeUnsafeData(
+      bufferHolder: BufferHolder,
+      baseObject: AnyRef,
+      baseOffset: Long,
+      sizeInBytes: Int) : Unit = {
+    bufferHolder.grow(sizeInBytes)
+    Platform.copyMemory(
+      baseObject,
+      baseOffset,
+      bufferHolder.buffer,
+      bufferHolder.cursor,
+      sizeInBytes)
+    bufferHolder.cursor += sizeInBytes
+  }
+}
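
As the class comment above notes, the projection reuses one mutable `UnsafeRow`, so a consumer that buffers results must copy each row first; a minimal sketch:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}

    // Copy before buffering; projection(row) returns the same mutable instance.
    def materialize(projection: UnsafeProjection, rows: Iterator[InternalRow]): Seq[UnsafeRow] =
      rows.map(row => projection(row).copy()).toSeq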

http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index 4523079..dd523d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{DataType, LongType}
       within each partition. The assumption is that the data frame has less than 1 billion
       partitions, and each partition has less than 8 billion records.
   """)
-case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic {
+case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {
 
   /**
    * Record ID within each partition. By being transient, count's value is reset to 0 every time
@@ -79,4 +79,6 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
   override def prettyName: String = "monotonically_increasing_id"
 
   override def sql: String = s"$prettyName()"
+
+  override def freshCopy(): MonotonicallyIncreasingID = MonotonicallyIncreasingID()
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 64b94f0..3cd7368 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -108,8 +108,7 @@ abstract class UnsafeProjection extends Projection {
   override def apply(row: InternalRow): UnsafeRow
 }
 
-object UnsafeProjection {
-
+trait UnsafeProjectionCreator {
   /**
    * Returns an UnsafeProjection for given StructType.
    *
@@ -127,13 +126,13 @@ object UnsafeProjection {
   }
 
   /**
-   * Returns an UnsafeProjection for given sequence of Expressions (bounded).
+   * Returns an UnsafeProjection for given sequence of bound Expressions.
    */
   def create(exprs: Seq[Expression]): UnsafeProjection = {
     val unsafeExprs = exprs.map(_ transform {
       case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
     })
-    GenerateUnsafeProjection.generate(unsafeExprs)
+    createProjection(unsafeExprs)
   }
 
   def create(expr: Expression): UnsafeProjection = create(Seq(expr))
@@ -147,6 +146,18 @@ object UnsafeProjection {
   }
 
   /**
+   * Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
+   */
+  protected def createProjection(exprs: Seq[Expression]): UnsafeProjection
+}
+
+object UnsafeProjection extends UnsafeProjectionCreator {
+
+  override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
+    GenerateUnsafeProjection.generate(exprs)
+  }
+
+  /**
    * Same as other create()'s but allowing enabling/disabling subexpression elimination.
    * TODO: refactor the plumbing and clean this up.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 22717f5..6682ba5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -247,7 +247,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
 
         for (int $index = 0; $index < $numElements; $index++) {
           if ($tmpInput.isNullAt($index)) {
-            $arrayWriter.setNull$primitiveTypeName($index);
+            $arrayWriter.setNull${elementOrOffsetSize}Bytes($index);
           } else {
             $writeElement
           }

http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 6c9937d..f366338 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -31,7 +31,7 @@ import org.apache.spark.util.random.XORShiftRandom
  *
  * Since this expression is stateful, it cannot be a case object.
  */
-abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic {
+abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful {
   /**
    * Record ID within each partition. By being transient, the Random Number Generator is
    * reset every time we serialize and deserialize and initialize it.
@@ -85,6 +85,8 @@ case class Rand(child: Expression) extends RDG {
       final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""",
       isNull = "false")
   }
+
+  override def freshCopy(): Rand = Rand(child)
 }
 
 object Rand {
@@ -120,6 +122,8 @@ case class Randn(child: Expression) extends RDG {
       final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""",
       isNull = "false")
   }
+
+  override def freshCopy(): Randn = Randn(child)
 }
 
 object Randn {

http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 84190f0..b4138ce 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -180,7 +180,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
         null, null)
     }
     intercept[RuntimeException] {
-      checkEvalutionWithUnsafeProjection(
+      checkEvaluationWithUnsafeProjection(
         CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
         null, null)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 58d0c07..c6343b1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -60,7 +60,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
     checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow)
     if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
-      checkEvalutionWithUnsafeProjection(expr, catalystValue, inputRow)
+      checkEvaluationWithUnsafeProjection(expr, catalystValue, inputRow)
     }
     checkEvaluationWithOptimization(expr, catalystValue, inputRow)
   }
@@ -187,11 +187,20 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     plan(inputRow).get(0, expression.dataType)
   }
 
-  protected def checkEvalutionWithUnsafeProjection(
+  protected def checkEvaluationWithUnsafeProjection(
       expression: Expression,
       expected: Any,
       inputRow: InternalRow = EmptyRow): Unit = {
-    val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow)
+    checkEvaluationWithUnsafeProjection(expression, expected, inputRow, UnsafeProjection)
+    checkEvaluationWithUnsafeProjection(expression, expected, inputRow, InterpretedUnsafeProjection)
+  }
+
+  protected def checkEvaluationWithUnsafeProjection(
+      expression: Expression,
+      expected: Any,
+      inputRow: InternalRow,
+      factory: UnsafeProjectionCreator): Unit = {
+    val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow, factory)
     val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
 
     if (expected == null) {
@@ -203,7 +212,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     } else {
       val lit = InternalRow(expected, expected)
       val expectedRow =
-        UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit)
+        factory.create(Array(expression.dataType, expression.dataType)).apply(lit)
       if (unsafeRow != expectedRow) {
         fail("Incorrect evaluation in unsafe mode: " +
           s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
@@ -213,7 +222,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
 
   private def evaluateWithUnsafeProjection(
       expression: Expression,
-      inputRow: InternalRow = EmptyRow): InternalRow = {
+      inputRow: InternalRow = EmptyRow,
+      factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = {
     // SPARK-16489 Explicitly doing code generation twice so code gen will fail if
     // some expression is reusing variable names across different instances.
     // This behavior is tested in ExpressionEvalHelperSuite.

http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index ffeec2a..1f6964d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -45,16 +45,22 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4))))
     val structExpected = new GenericArrayData(
       Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4))))
-    checkEvalutionWithUnsafeProjection(
-      structEncoder.serializer.head, structExpected, structInputRow)
+    checkEvaluationWithUnsafeProjection(
+      structEncoder.serializer.head,
+      structExpected,
+      structInputRow,
+      UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
 
     // test UnsafeArray-backed data
     val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]
     val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4))))
     val arrayExpected = new GenericArrayData(
       Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4))))
-    checkEvalutionWithUnsafeProjection(
-      arrayEncoder.serializer.head, arrayExpected, arrayInputRow)
+    checkEvaluationWithUnsafeProjection(
+      arrayEncoder.serializer.head,
+      arrayExpected,
+      arrayInputRow,
+      UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
 
     // test UnsafeMap-backed data
     val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]
@@ -67,8 +73,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       new ArrayBasedMapData(
         new GenericArrayData(Array(3, 4)),
         new GenericArrayData(Array(300, 400)))))
-    checkEvalutionWithUnsafeProjection(
-      mapEncoder.serializer.head, mapExpected, mapInputRow)
+    checkEvaluationWithUnsafeProjection(
+      mapEncoder.serializer.head,
+      mapExpected,
+      mapInputRow,
+      UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
   }
 
   test("SPARK-23585: UnwrapOption should support interpreted execution") {

http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
index 10e3ffd..e083ae0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -43,7 +43,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(e1.getMessage.contains("Failed to execute user defined function"))
 
     val e2 = intercept[SparkException] {
-      checkEvalutionWithUnsafeProjection(udf, null)
+      checkEvaluationWithUnsafeProjection(udf, null)
     }
     assert(e2.getMessage.contains("Failed to execute user defined function"))
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index cf3cbe2..c07da12 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.Matchers
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.{IntegerType, LongType, _}
 import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -33,10 +33,18 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
 
   private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size)
 
-  test("basic conversion with only primitive types") {
-    val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
-    val converter = UnsafeProjection.create(fieldTypes)
+  private def testWithFactory(
+    name: String)(
+    f: UnsafeProjectionCreator => Unit): Unit = {
+    test(name) {
+      f(UnsafeProjection)
+      f(InterpretedUnsafeProjection)
+    }
+  }
 
+  testWithFactory("basic conversion with only primitive types") { factory =>
+    val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
+    val converter = factory.create(fieldTypes)
     val row = new SpecificInternalRow(fieldTypes)
     row.setLong(0, 0)
     row.setLong(1, 1)
@@ -71,9 +79,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(unsafeRow2.getInt(2) === 2)
   }
 
-  test("basic conversion with primitive, string and binary types") {
+  testWithFactory("basic conversion with primitive, string and binary types") { factory =>
     val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType)
-    val converter = UnsafeProjection.create(fieldTypes)
+    val converter = factory.create(fieldTypes)
 
     val row = new SpecificInternalRow(fieldTypes)
     row.setLong(0, 0)
@@ -90,9 +98,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(unsafeRow.getBinary(2) === "World".getBytes(StandardCharsets.UTF_8))
   }
 
-  test("basic conversion with primitive, string, date and timestamp types") {
+  testWithFactory("basic conversion with primitive, string, date and timestamp types") { factory =>
     val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType)
-    val converter = UnsafeProjection.create(fieldTypes)
+    val converter = factory.create(fieldTypes)
 
     val row = new SpecificInternalRow(fieldTypes)
     row.setLong(0, 0)
@@ -119,7 +127,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     (Timestamp.valueOf("2015-06-22 08:10:25"))
   }
 
-  test("null handling") {
+  testWithFactory("null handling") { factory =>
     val fieldTypes: Array[DataType] = Array(
       NullType,
       BooleanType,
@@ -135,7 +143,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       DecimalType.SYSTEM_DEFAULT
       // ArrayType(IntegerType)
     )
-    val converter = UnsafeProjection.create(fieldTypes)
+    val converter = factory.create(fieldTypes)
 
     val rowWithAllNullColumns: InternalRow = {
       val r = new SpecificInternalRow(fieldTypes)
@@ -240,7 +248,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
   }
 
-  test("NaN canonicalization") {
+  testWithFactory("NaN canonicalization") { factory =>
     val fieldTypes: Array[DataType] = Array(FloatType, DoubleType)
 
     val row1 = new SpecificInternalRow(fieldTypes)
@@ -251,17 +259,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff))
     row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL))
 
-    val converter = UnsafeProjection.create(fieldTypes)
+    val converter = factory.create(fieldTypes)
     assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes)
   }
 
-  test("basic conversion with struct type") {
+  testWithFactory("basic conversion with struct type") { factory =>
     val fieldTypes: Array[DataType] = Array(
       new StructType().add("i", IntegerType),
       new StructType().add("nest", new StructType().add("l", LongType))
     )
 
-    val converter = UnsafeProjection.create(fieldTypes)
+    val converter = factory.create(fieldTypes)
 
     val row = new GenericInternalRow(fieldTypes.length)
     row.update(0, InternalRow(1))
@@ -317,12 +325,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
   }
 
-  test("basic conversion with array type") {
+  testWithFactory("basic conversion with array type") { factory =>
     val fieldTypes: Array[DataType] = Array(
       ArrayType(IntegerType),
       ArrayType(ArrayType(IntegerType))
     )
-    val converter = UnsafeProjection.create(fieldTypes)
+    val converter = factory.create(fieldTypes)
 
     val row = new GenericInternalRow(fieldTypes.length)
     row.update(0, createArray(1, 2))
@@ -347,12 +355,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size)
   }
 
-  test("basic conversion with map type") {
+  testWithFactory("basic conversion with map type") { factory =>
     val fieldTypes: Array[DataType] = Array(
       MapType(IntegerType, IntegerType),
       MapType(IntegerType, MapType(IntegerType, IntegerType))
     )
-    val converter = UnsafeProjection.create(fieldTypes)
+    val converter = factory.create(fieldTypes)
 
     val map1 = createMap(1, 2)(3, 4)
 
@@ -393,12 +401,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
   }
 
-  test("basic conversion with struct and array") {
+  testWithFactory("basic conversion with struct and array") { factory =>
     val fieldTypes: Array[DataType] = Array(
       new StructType().add("arr", ArrayType(IntegerType)),
       ArrayType(new StructType().add("l", LongType))
     )
-    val converter = UnsafeProjection.create(fieldTypes)
+    val converter = factory.create(fieldTypes)
 
     val row = new GenericInternalRow(fieldTypes.length)
     row.update(0, InternalRow(createArray(1)))
@@ -432,12 +440,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
   }
 
-  test("basic conversion with struct and map") {
+  testWithFactory("basic conversion with struct and map") { factory =>
     val fieldTypes: Array[DataType] = Array(
       new StructType().add("map", MapType(IntegerType, IntegerType)),
       MapType(IntegerType, new StructType().add("l", LongType))
     )
-    val converter = UnsafeProjection.create(fieldTypes)
+    val converter = factory.create(fieldTypes)
 
     val row = new GenericInternalRow(fieldTypes.length)
     row.update(0, InternalRow(createMap(1)(2)))
@@ -478,12 +486,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
   }
 
-  test("basic conversion with array and map") {
+  testWithFactory("basic conversion with array and map") { factory =>
     val fieldTypes: Array[DataType] = Array(
       ArrayType(MapType(IntegerType, IntegerType)),
       MapType(IntegerType, ArrayType(IntegerType))
     )
-    val converter = UnsafeProjection.create(fieldTypes)
+    val converter = factory.create(fieldTypes)
 
     val row = new GenericInternalRow(fieldTypes.length)
     row.update(0, createArray(createMap(1)(2)))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org