You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2018/03/16 17:28:22 UTC
spark git commit: [SPARK-23581][SQL] Add interpreted unsafe projection
Repository: spark
Updated Branches:
refs/heads/master dffeac369 -> 88d8de926
[SPARK-23581][SQL] Add interpreted unsafe projection
## What changes were proposed in this pull request?
Currently, we can only create unsafe rows using code generation. This is a problem when code generation fails: there is no fallback, and as a result the query cannot be executed.
This PR adds an interpreted version of `UnsafeProjection`. The implementation is modeled after `InterpretedMutableProjection`. It stores the expression results in a `GenericInternalRow`, and it then uses a conversion function to convert the `GenericInternalRow` into an `UnsafeRow`.
This PR does not implement the actual fallback logic from generated code to interpreted execution; that will be done in a follow-up.
## How was this patch tested?
I am piggybacking on existing `UnsafeProjection` tests, and I have added an interpreted version of each of them.
Author: Herman van Hovell <hv...@databricks.com>
Closes #20750 from hvanhovell/SPARK-23581.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/88d8de92
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/88d8de92
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/88d8de92
Branch: refs/heads/master
Commit: 88d8de9260edf6e9d5449ff7ef6e35d16051fc9f
Parents: dffeac3
Author: Herman van Hovell <hv...@databricks.com>
Authored: Fri Mar 16 18:28:16 2018 +0100
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Fri Mar 16 18:28:16 2018 +0100
----------------------------------------------------------------------
.../expressions/codegen/UnsafeArrayWriter.java | 32 +-
.../expressions/codegen/UnsafeRowWriter.java | 30 +-
.../expressions/codegen/UnsafeWriter.java | 43 +++
.../sql/catalyst/expressions/Expression.scala | 26 ++
.../InterpretedUnsafeProjection.scala | 366 +++++++++++++++++++
.../expressions/MonotonicallyIncreasingID.scala | 4 +-
.../sql/catalyst/expressions/Projection.scala | 19 +-
.../codegen/GenerateUnsafeProjection.scala | 2 +-
.../expressions/randomExpressions.scala | 6 +-
.../catalyst/expressions/ComplexTypeSuite.scala | 2 +-
.../expressions/ExpressionEvalHelper.scala | 20 +-
.../expressions/ObjectExpressionsSuite.scala | 21 +-
.../catalyst/expressions/ScalaUDFSuite.scala | 2 +-
.../expressions/UnsafeRowConverterSuite.scala | 56 +--
14 files changed, 555 insertions(+), 74 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
index 791e8d8..82cd1b2 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -30,7 +30,7 @@ import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculat
* A helper class to write data into global row buffer using `UnsafeArrayData` format,
* used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
*/
-public class UnsafeArrayWriter {
+public final class UnsafeArrayWriter extends UnsafeWriter {
private BufferHolder holder;
@@ -83,7 +83,7 @@ public class UnsafeArrayWriter {
return startingOffset + headerInBytes + ordinal * elementSize;
}
- public void setOffsetAndSize(int ordinal, long currentCursor, int size) {
+ public void setOffsetAndSize(int ordinal, int currentCursor, int size) {
assertIndexIsValid(ordinal);
final long relativeOffset = currentCursor - startingOffset;
final long offsetAndSize = (relativeOffset << 32) | (long)size;
@@ -96,49 +96,31 @@ public class UnsafeArrayWriter {
BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal);
}
- public void setNullBoolean(int ordinal) {
- setNullBit(ordinal);
- // put zero into the corresponding field when set null
- Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), false);
- }
-
- public void setNullByte(int ordinal) {
+ public void setNull1Bytes(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0);
}
- public void setNullShort(int ordinal) {
+ public void setNull2Bytes(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0);
}
- public void setNullInt(int ordinal) {
+ public void setNull4Bytes(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0);
}
- public void setNullLong(int ordinal) {
+ public void setNull8Bytes(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0);
}
- public void setNullFloat(int ordinal) {
- setNullBit(ordinal);
- // put zero into the corresponding field when set null
- Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), (float)0);
- }
-
- public void setNullDouble(int ordinal) {
- setNullBit(ordinal);
- // put zero into the corresponding field when set null
- Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), (double)0);
- }
-
- public void setNull(int ordinal) { setNullLong(ordinal); }
+ public void setNull(int ordinal) { setNull8Bytes(ordinal); }
public void write(int ordinal, boolean value) {
assertIndexIsValid(ordinal);
http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
index 5d9515c..2620bbc 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -38,7 +38,7 @@ import org.apache.spark.unsafe.types.UTF8String;
* beginning of the global row buffer, we don't need to update `startingOffset` and can just call
* `zeroOutNullBytes` before writing new data.
*/
-public class UnsafeRowWriter {
+public final class UnsafeRowWriter extends UnsafeWriter {
private final BufferHolder holder;
// The offset of the global buffer where we start to write this row.
@@ -93,18 +93,38 @@ public class UnsafeRowWriter {
Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L);
}
+ @Override
+ public void setNull1Bytes(int ordinal) {
+ setNullAt(ordinal);
+ }
+
+ @Override
+ public void setNull2Bytes(int ordinal) {
+ setNullAt(ordinal);
+ }
+
+ @Override
+ public void setNull4Bytes(int ordinal) {
+ setNullAt(ordinal);
+ }
+
+ @Override
+ public void setNull8Bytes(int ordinal) {
+ setNullAt(ordinal);
+ }
+
public long getFieldOffset(int ordinal) {
return startingOffset + nullBitsSize + 8 * ordinal;
}
- public void setOffsetAndSize(int ordinal, long size) {
+ public void setOffsetAndSize(int ordinal, int size) {
setOffsetAndSize(ordinal, holder.cursor, size);
}
- public void setOffsetAndSize(int ordinal, long currentCursor, long size) {
+ public void setOffsetAndSize(int ordinal, int currentCursor, int size) {
final long relativeOffset = currentCursor - startingOffset;
final long fieldOffset = getFieldOffset(ordinal);
- final long offsetAndSize = (relativeOffset << 32) | size;
+ final long offsetAndSize = (relativeOffset << 32) | (long) size;
Platform.putLong(holder.buffer, fieldOffset, offsetAndSize);
}
@@ -174,7 +194,7 @@ public class UnsafeRowWriter {
if (input == null || !input.changePrecision(precision, scale)) {
BitSetMethods.set(holder.buffer, startingOffset, ordinal);
// keep the offset for future update
- setOffsetAndSize(ordinal, 0L);
+ setOffsetAndSize(ordinal, 0);
} else {
final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
assert bytes.length <= 16;
http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
new file mode 100644
index 0000000..c94b5c7
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * Base class for writing Unsafe* structures.
+ */
+public abstract class UnsafeWriter {
+ public abstract void setNull1Bytes(int ordinal);
+ public abstract void setNull2Bytes(int ordinal);
+ public abstract void setNull4Bytes(int ordinal);
+ public abstract void setNull8Bytes(int ordinal);
+ public abstract void write(int ordinal, boolean value);
+ public abstract void write(int ordinal, byte value);
+ public abstract void write(int ordinal, short value);
+ public abstract void write(int ordinal, int value);
+ public abstract void write(int ordinal, long value);
+ public abstract void write(int ordinal, float value);
+ public abstract void write(int ordinal, double value);
+ public abstract void write(int ordinal, Decimal input, int precision, int scale);
+ public abstract void write(int ordinal, UTF8String input);
+ public abstract void write(int ordinal, byte[] input);
+ public abstract void write(int ordinal, CalendarInterval input);
+ public abstract void setOffsetAndSize(int ordinal, int currentCursor, int size);
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index ed90b18..d7f9e38 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -328,6 +328,32 @@ trait Nondeterministic extends Expression {
protected def evalInternal(input: InternalRow): Any
}
+/**
+ * An expression that contains mutable state. A stateful expression is always non-deterministic
+ * because the results it produces during evaluation are not only dependent on the given input
+ * but also on its internal state.
+ *
+ * The state of the expressions is generally not exposed in the parameter list and this makes
+ * comparing stateful expressions problematic because similar stateful expressions (with the same
+ * parameter list) but with different internal state will be considered equal. This is especially
+ * problematic during tree transformations. In order to counter this the `fastEquals` method for
+ * stateful expressions only returns `true` for the same reference.
+ *
+ * A stateful expression should never be evaluated multiple times for a single row. This should
+ * only be a problem for interpreted execution. This can be prevented by creating fresh copies
+ * of the stateful expression before execution, these can be made using the `freshCopy` function.
+ */
+trait Stateful extends Nondeterministic {
+ /**
+ * Return a fresh uninitialized copy of the stateful expression.
+ */
+ def freshCopy(): Stateful
+
+ /**
+ * Only the same reference is considered equal.
+ */
+ override def fastEquals(other: TreeNode[_]): Boolean = this eq other
+}
/**
* A leaf expression, i.e. one without any child expressions.
http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
new file mode 100644
index 0000000..0da5ece
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter}
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.types.{UserDefinedType, _}
+import org.apache.spark.unsafe.Platform
+
+/**
+ * An interpreted unsafe projection. This class reuses the [[UnsafeRow]] it produces, a consumer
+ * should copy the row if it is being buffered. This class is not thread safe.
+ *
+ * @param expressions that produces the resulting fields. These expressions must be bound
+ * to a schema.
+ */
+class InterpretedUnsafeProjection(expressions: Array[Expression]) extends UnsafeProjection {
+ import InterpretedUnsafeProjection._
+
+ /** Number of (top level) fields in the resulting row. */
+ private[this] val numFields = expressions.length
+
+ /** Array that expression results. */
+ private[this] val values = new Array[Any](numFields)
+
+ /** The row representing the expression results. */
+ private[this] val intermediate = new GenericInternalRow(values)
+
+ /** The row returned by the projection. */
+ private[this] val result = new UnsafeRow(numFields)
+
+ /** The buffer which holds the resulting row's backing data. */
+ private[this] val holder = new BufferHolder(result, numFields * 32)
+
+ /** The writer that writes the intermediate result to the result row. */
+ private[this] val writer: InternalRow => Unit = {
+ val rowWriter = new UnsafeRowWriter(holder, numFields)
+ val baseWriter = generateStructWriter(
+ holder,
+ rowWriter,
+ expressions.map(e => StructField("", e.dataType, e.nullable)))
+ if (!expressions.exists(_.nullable)) {
+ // No nullable fields. The top-level null bit mask will always be zeroed out.
+ baseWriter
+ } else {
+ // Zero out the null bit mask before we write the row.
+ row => {
+ rowWriter.zeroOutNullBytes()
+ baseWriter(row)
+ }
+ }
+ }
+
+ override def initialize(partitionIndex: Int): Unit = {
+ expressions.foreach(_.foreach {
+ case n: Nondeterministic => n.initialize(partitionIndex)
+ case _ =>
+ })
+ }
+
+ override def apply(row: InternalRow): UnsafeRow = {
+ // Put the expression results in the intermediate row.
+ var i = 0
+ while (i < numFields) {
+ values(i) = expressions(i).eval(row)
+ i += 1
+ }
+
+ // Write the intermediate row to an unsafe row.
+ holder.reset()
+ writer(intermediate)
+ result.setTotalSize(holder.totalSize())
+ result
+ }
+}
+
+/**
+ * Helper functions for creating an [[InterpretedUnsafeProjection]].
+ */
+object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
+
+ /**
+ * Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
+ */
+ override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
+ // We need to make sure that we do not reuse stateful expressions.
+ val cleanedExpressions = exprs.map(_.transform {
+ case s: Stateful => s.freshCopy()
+ })
+ new InterpretedUnsafeProjection(cleanedExpressions.toArray)
+ }
+
+ /**
+ * Generate a struct writer function. The generated function writes an [[InternalRow]] to the
+ * given buffer using the given [[UnsafeRowWriter]].
+ */
+ private def generateStructWriter(
+ bufferHolder: BufferHolder,
+ rowWriter: UnsafeRowWriter,
+ fields: Array[StructField]): InternalRow => Unit = {
+ val numFields = fields.length
+
+ // Create field writers.
+ val fieldWriters = fields.map { field =>
+ generateFieldWriter(bufferHolder, rowWriter, field.dataType, field.nullable)
+ }
+ // Create basic writer.
+ row => {
+ var i = 0
+ while (i < numFields) {
+ fieldWriters(i).apply(row, i)
+ i += 1
+ }
+ }
+ }
+
+ /**
+ * Generate a writer function for a struct field, array element, map key or map value. The
+ * generated function writes the element at an index in a [[SpecializedGetters]] object (row
+ * or array) to the given buffer using the given [[UnsafeWriter]].
+ */
+ private def generateFieldWriter(
+ bufferHolder: BufferHolder,
+ writer: UnsafeWriter,
+ dt: DataType,
+ nullable: Boolean): (SpecializedGetters, Int) => Unit = {
+
+ // Create the the basic writer.
+ val unsafeWriter: (SpecializedGetters, Int) => Unit = dt match {
+ case BooleanType =>
+ (v, i) => writer.write(i, v.getBoolean(i))
+
+ case ByteType =>
+ (v, i) => writer.write(i, v.getByte(i))
+
+ case ShortType =>
+ (v, i) => writer.write(i, v.getShort(i))
+
+ case IntegerType | DateType =>
+ (v, i) => writer.write(i, v.getInt(i))
+
+ case LongType | TimestampType =>
+ (v, i) => writer.write(i, v.getLong(i))
+
+ case FloatType =>
+ (v, i) => writer.write(i, v.getFloat(i))
+
+ case DoubleType =>
+ (v, i) => writer.write(i, v.getDouble(i))
+
+ case DecimalType.Fixed(precision, scale) =>
+ (v, i) => writer.write(i, v.getDecimal(i, precision, scale), precision, scale)
+
+ case CalendarIntervalType =>
+ (v, i) => writer.write(i, v.getInterval(i))
+
+ case BinaryType =>
+ (v, i) => writer.write(i, v.getBinary(i))
+
+ case StringType =>
+ (v, i) => writer.write(i, v.getUTF8String(i))
+
+ case StructType(fields) =>
+ val numFields = fields.length
+ val rowWriter = new UnsafeRowWriter(bufferHolder, numFields)
+ val structWriter = generateStructWriter(bufferHolder, rowWriter, fields)
+ (v, i) => {
+ val tmpCursor = bufferHolder.cursor
+ v.getStruct(i, fields.length) match {
+ case row: UnsafeRow =>
+ writeUnsafeData(
+ bufferHolder,
+ row.getBaseObject,
+ row.getBaseOffset,
+ row.getSizeInBytes)
+ case row =>
+ // Nested struct. We don't know where this will start because a row can be
+ // variable length, so we need to update the offsets and zero out the bit mask.
+ rowWriter.reset()
+ structWriter.apply(row)
+ }
+ writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor)
+ }
+
+ case ArrayType(elementType, containsNull) =>
+ val arrayWriter = new UnsafeArrayWriter
+ val elementSize = getElementSize(elementType)
+ val elementWriter = generateFieldWriter(
+ bufferHolder,
+ arrayWriter,
+ elementType,
+ containsNull)
+ (v, i) => {
+ val tmpCursor = bufferHolder.cursor
+ writeArray(bufferHolder, arrayWriter, elementWriter, v.getArray(i), elementSize)
+ writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor)
+ }
+
+ case MapType(keyType, valueType, valueContainsNull) =>
+ val keyArrayWriter = new UnsafeArrayWriter
+ val keySize = getElementSize(keyType)
+ val keyWriter = generateFieldWriter(
+ bufferHolder,
+ keyArrayWriter,
+ keyType,
+ nullable = false)
+ val valueArrayWriter = new UnsafeArrayWriter
+ val valueSize = getElementSize(valueType)
+ val valueWriter = generateFieldWriter(
+ bufferHolder,
+ valueArrayWriter,
+ valueType,
+ valueContainsNull)
+ (v, i) => {
+ val tmpCursor = bufferHolder.cursor
+ v.getMap(i) match {
+ case map: UnsafeMapData =>
+ writeUnsafeData(
+ bufferHolder,
+ map.getBaseObject,
+ map.getBaseOffset,
+ map.getSizeInBytes)
+ case map =>
+ // preserve 8 bytes to write the key array numBytes later.
+ bufferHolder.grow(8)
+ bufferHolder.cursor += 8
+
+ // Write the keys and write the numBytes of key array into the first 8 bytes.
+ writeArray(bufferHolder, keyArrayWriter, keyWriter, map.keyArray(), keySize)
+ Platform.putLong(bufferHolder.buffer, tmpCursor, bufferHolder.cursor - tmpCursor - 8)
+
+ // Write the values.
+ writeArray(bufferHolder, valueArrayWriter, valueWriter, map.valueArray(), valueSize)
+ }
+ writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor)
+ }
+
+ case udt: UserDefinedType[_] =>
+ generateFieldWriter(bufferHolder, writer, udt.sqlType, nullable)
+
+ case NullType =>
+ (_, _) => {}
+
+ case _ =>
+ throw new SparkException(s"Unsupported data type $dt")
+ }
+
+ // Always wrap the writer with a null safe version.
+ dt match {
+ case _: UserDefinedType[_] =>
+ // The null wrapper depends on the sql type and not on the UDT.
+ unsafeWriter
+ case DecimalType.Fixed(precision, _) if precision > Decimal.MAX_LONG_DIGITS =>
+ // We can't call setNullAt() for DecimalType with precision larger than 18, we call write
+ // directly. We can use the unwrapped writer directly.
+ unsafeWriter
+ case BooleanType | ByteType =>
+ (v, i) => {
+ if (!v.isNullAt(i)) {
+ unsafeWriter(v, i)
+ } else {
+ writer.setNull1Bytes(i)
+ }
+ }
+ case ShortType =>
+ (v, i) => {
+ if (!v.isNullAt(i)) {
+ unsafeWriter(v, i)
+ } else {
+ writer.setNull2Bytes(i)
+ }
+ }
+ case IntegerType | DateType | FloatType =>
+ (v, i) => {
+ if (!v.isNullAt(i)) {
+ unsafeWriter(v, i)
+ } else {
+ writer.setNull4Bytes(i)
+ }
+ }
+ case _ =>
+ (v, i) => {
+ if (!v.isNullAt(i)) {
+ unsafeWriter(v, i)
+ } else {
+ writer.setNull8Bytes(i)
+ }
+ }
+ }
+ }
+
+ /**
+ * Get the number of bytes elements of a data type will occupy in the fixed part of an
+ * [[UnsafeArrayData]] object. Reference types are stored as an 8 byte combination of an
+ * offset (upper 4 bytes) and a length (lower 4 bytes), these point to the variable length
+ * portion of the array object. Primitives take up to 8 bytes, depending on the size of the
+ * underlying data type.
+ */
+ private def getElementSize(dataType: DataType): Int = dataType match {
+ case NullType | StringType | BinaryType | CalendarIntervalType |
+ _: DecimalType | _: StructType | _: ArrayType | _: MapType => 8
+ case _ => dataType.defaultSize
+ }
+
+ /**
+ * Write an array to the buffer. If the array is already in serialized form (an instance of
+ * [[UnsafeArrayData]]) then we copy the bytes directly, otherwise we do an element-by-element
+ * copy.
+ */
+ private def writeArray(
+ bufferHolder: BufferHolder,
+ arrayWriter: UnsafeArrayWriter,
+ elementWriter: (SpecializedGetters, Int) => Unit,
+ array: ArrayData,
+ elementSize: Int): Unit = array match {
+ case unsafe: UnsafeArrayData =>
+ writeUnsafeData(
+ bufferHolder,
+ unsafe.getBaseObject,
+ unsafe.getBaseOffset,
+ unsafe.getSizeInBytes)
+ case _ =>
+ val numElements = array.numElements()
+ arrayWriter.initialize(bufferHolder, numElements, elementSize)
+ var i = 0
+ while (i < numElements) {
+ elementWriter.apply(array, i)
+ i += 1
+ }
+ }
+
+ /**
+ * Write an opaque block of data to the buffer. This is used to copy
+ * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects.
+ */
+ private def writeUnsafeData(
+ bufferHolder: BufferHolder,
+ baseObject: AnyRef,
+ baseOffset: Long,
+ sizeInBytes: Int) : Unit = {
+ bufferHolder.grow(sizeInBytes)
+ Platform.copyMemory(
+ baseObject,
+ baseOffset,
+ bufferHolder.buffer,
+ bufferHolder.cursor,
+ sizeInBytes)
+ bufferHolder.cursor += sizeInBytes
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index 4523079..dd523d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{DataType, LongType}
within each partition. The assumption is that the data frame has less than 1 billion
partitions, and each partition has less than 8 billion records.
""")
-case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic {
+case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {
/**
* Record ID within each partition. By being transient, count's value is reset to 0 every time
@@ -79,4 +79,6 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
override def prettyName: String = "monotonically_increasing_id"
override def sql: String = s"$prettyName()"
+
+ override def freshCopy(): MonotonicallyIncreasingID = MonotonicallyIncreasingID()
}
http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 64b94f0..3cd7368 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -108,8 +108,7 @@ abstract class UnsafeProjection extends Projection {
override def apply(row: InternalRow): UnsafeRow
}
-object UnsafeProjection {
-
+trait UnsafeProjectionCreator {
/**
* Returns an UnsafeProjection for given StructType.
*
@@ -127,13 +126,13 @@ object UnsafeProjection {
}
/**
- * Returns an UnsafeProjection for given sequence of Expressions (bounded).
+ * Returns an UnsafeProjection for given sequence of bound Expressions.
*/
def create(exprs: Seq[Expression]): UnsafeProjection = {
val unsafeExprs = exprs.map(_ transform {
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
- GenerateUnsafeProjection.generate(unsafeExprs)
+ createProjection(unsafeExprs)
}
def create(expr: Expression): UnsafeProjection = create(Seq(expr))
@@ -147,6 +146,18 @@ object UnsafeProjection {
}
/**
+ * Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
+ */
+ protected def createProjection(exprs: Seq[Expression]): UnsafeProjection
+}
+
+object UnsafeProjection extends UnsafeProjectionCreator {
+
+ override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
+ GenerateUnsafeProjection.generate(exprs)
+ }
+
+ /**
* Same as other create()'s but allowing enabling/disabling subexpression elimination.
* TODO: refactor the plumbing and clean this up.
*/
http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 22717f5..6682ba5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -247,7 +247,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
for (int $index = 0; $index < $numElements; $index++) {
if ($tmpInput.isNullAt($index)) {
- $arrayWriter.setNull$primitiveTypeName($index);
+ $arrayWriter.setNull${elementOrOffsetSize}Bytes($index);
} else {
$writeElement
}
http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 6c9937d..f366338 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -31,7 +31,7 @@ import org.apache.spark.util.random.XORShiftRandom
*
* Since this expression is stateful, it cannot be a case object.
*/
-abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic {
+abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful {
/**
* Record ID within each partition. By being transient, the Random Number Generator is
* reset every time we serialize and deserialize and initialize it.
@@ -85,6 +85,8 @@ case class Rand(child: Expression) extends RDG {
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""",
isNull = "false")
}
+
+ override def freshCopy(): Rand = Rand(child)
}
object Rand {
@@ -120,6 +122,8 @@ case class Randn(child: Expression) extends RDG {
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""",
isNull = "false")
}
+
+ override def freshCopy(): Randn = Randn(child)
}
object Randn {
http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 84190f0..b4138ce 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -180,7 +180,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
null, null)
}
intercept[RuntimeException] {
- checkEvalutionWithUnsafeProjection(
+ checkEvaluationWithUnsafeProjection(
CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
null, null)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 58d0c07..c6343b1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -60,7 +60,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow)
if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
- checkEvalutionWithUnsafeProjection(expr, catalystValue, inputRow)
+ checkEvaluationWithUnsafeProjection(expr, catalystValue, inputRow)
}
checkEvaluationWithOptimization(expr, catalystValue, inputRow)
}
@@ -187,11 +187,20 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
plan(inputRow).get(0, expression.dataType)
}
- protected def checkEvalutionWithUnsafeProjection(
+ protected def checkEvaluationWithUnsafeProjection(
expression: Expression,
expected: Any,
inputRow: InternalRow = EmptyRow): Unit = {
- val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow)
+ checkEvaluationWithUnsafeProjection(expression, expected, inputRow, UnsafeProjection)
+ checkEvaluationWithUnsafeProjection(expression, expected, inputRow, InterpretedUnsafeProjection)
+ }
+
+ protected def checkEvaluationWithUnsafeProjection(
+ expression: Expression,
+ expected: Any,
+ inputRow: InternalRow,
+ factory: UnsafeProjectionCreator): Unit = {
+ val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow, factory)
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
if (expected == null) {
@@ -203,7 +212,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
} else {
val lit = InternalRow(expected, expected)
val expectedRow =
- UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit)
+ factory.create(Array(expression.dataType, expression.dataType)).apply(lit)
if (unsafeRow != expectedRow) {
fail("Incorrect evaluation in unsafe mode: " +
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
@@ -213,7 +222,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
private def evaluateWithUnsafeProjection(
expression: Expression,
- inputRow: InternalRow = EmptyRow): InternalRow = {
+ inputRow: InternalRow = EmptyRow,
+ factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = {
// SPARK-16489 Explicitly doing code generation twice so code gen will fail if
// some expression is reusing variable names across different instances.
// This behavior is tested in ExpressionEvalHelperSuite.
http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index ffeec2a..1f6964d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -45,16 +45,22 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4))))
val structExpected = new GenericArrayData(
Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4))))
- checkEvalutionWithUnsafeProjection(
- structEncoder.serializer.head, structExpected, structInputRow)
+ checkEvaluationWithUnsafeProjection(
+ structEncoder.serializer.head,
+ structExpected,
+ structInputRow,
+ UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
// test UnsafeArray-backed data
val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]
val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4))))
val arrayExpected = new GenericArrayData(
Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4))))
- checkEvalutionWithUnsafeProjection(
- arrayEncoder.serializer.head, arrayExpected, arrayInputRow)
+ checkEvaluationWithUnsafeProjection(
+ arrayEncoder.serializer.head,
+ arrayExpected,
+ arrayInputRow,
+ UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
// test UnsafeMap-backed data
val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]
@@ -67,8 +73,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
new ArrayBasedMapData(
new GenericArrayData(Array(3, 4)),
new GenericArrayData(Array(300, 400)))))
- checkEvalutionWithUnsafeProjection(
- mapEncoder.serializer.head, mapExpected, mapInputRow)
+ checkEvaluationWithUnsafeProjection(
+ mapEncoder.serializer.head,
+ mapExpected,
+ mapInputRow,
+ UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
}
test("SPARK-23585: UnwrapOption should support interpreted execution") {
http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
index 10e3ffd..e083ae0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -43,7 +43,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(e1.getMessage.contains("Failed to execute user defined function"))
val e2 = intercept[SparkException] {
- checkEvalutionWithUnsafeProjection(udf, null)
+ checkEvaluationWithUnsafeProjection(udf, null)
}
assert(e2.getMessage.contains("Failed to execute user defined function"))
}
http://git-wip-us.apache.org/repos/asf/spark/blob/88d8de92/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index cf3cbe2..c07da12 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.Matchers
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.{IntegerType, LongType, _}
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.UTF8String
@@ -33,10 +33,18 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size)
- test("basic conversion with only primitive types") {
- val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
- val converter = UnsafeProjection.create(fieldTypes)
+ private def testWithFactory(
+ name: String)(
+ f: UnsafeProjectionCreator => Unit): Unit = {
+ test(name) {
+ f(UnsafeProjection)
+ f(InterpretedUnsafeProjection)
+ }
+ }
+ testWithFactory("basic conversion with only primitive types") { factory =>
+ val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
+ val converter = factory.create(fieldTypes)
val row = new SpecificInternalRow(fieldTypes)
row.setLong(0, 0)
row.setLong(1, 1)
@@ -71,9 +79,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(unsafeRow2.getInt(2) === 2)
}
- test("basic conversion with primitive, string and binary types") {
+ testWithFactory("basic conversion with primitive, string and binary types") { factory =>
val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val row = new SpecificInternalRow(fieldTypes)
row.setLong(0, 0)
@@ -90,9 +98,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(unsafeRow.getBinary(2) === "World".getBytes(StandardCharsets.UTF_8))
}
- test("basic conversion with primitive, string, date and timestamp types") {
+ testWithFactory("basic conversion with primitive, string, date and timestamp types") { factory =>
val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val row = new SpecificInternalRow(fieldTypes)
row.setLong(0, 0)
@@ -119,7 +127,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
(Timestamp.valueOf("2015-06-22 08:10:25"))
}
- test("null handling") {
+ testWithFactory("null handling") { factory =>
val fieldTypes: Array[DataType] = Array(
NullType,
BooleanType,
@@ -135,7 +143,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
DecimalType.SYSTEM_DEFAULT
// ArrayType(IntegerType)
)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val rowWithAllNullColumns: InternalRow = {
val r = new SpecificInternalRow(fieldTypes)
@@ -240,7 +248,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
// assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
}
- test("NaN canonicalization") {
+ testWithFactory("NaN canonicalization") { factory =>
val fieldTypes: Array[DataType] = Array(FloatType, DoubleType)
val row1 = new SpecificInternalRow(fieldTypes)
@@ -251,17 +259,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff))
row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL))
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes)
}
- test("basic conversion with struct type") {
+ testWithFactory("basic conversion with struct type") { factory =>
val fieldTypes: Array[DataType] = Array(
new StructType().add("i", IntegerType),
new StructType().add("nest", new StructType().add("l", LongType))
)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val row = new GenericInternalRow(fieldTypes.length)
row.update(0, InternalRow(1))
@@ -317,12 +325,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
}
- test("basic conversion with array type") {
+ testWithFactory("basic conversion with array type") { factory =>
val fieldTypes: Array[DataType] = Array(
ArrayType(IntegerType),
ArrayType(ArrayType(IntegerType))
)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val row = new GenericInternalRow(fieldTypes.length)
row.update(0, createArray(1, 2))
@@ -347,12 +355,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size)
}
- test("basic conversion with map type") {
+ testWithFactory("basic conversion with map type") { factory =>
val fieldTypes: Array[DataType] = Array(
MapType(IntegerType, IntegerType),
MapType(IntegerType, MapType(IntegerType, IntegerType))
)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val map1 = createMap(1, 2)(3, 4)
@@ -393,12 +401,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
}
- test("basic conversion with struct and array") {
+ testWithFactory("basic conversion with struct and array") { factory =>
val fieldTypes: Array[DataType] = Array(
new StructType().add("arr", ArrayType(IntegerType)),
ArrayType(new StructType().add("l", LongType))
)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val row = new GenericInternalRow(fieldTypes.length)
row.update(0, InternalRow(createArray(1)))
@@ -432,12 +440,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
}
- test("basic conversion with struct and map") {
+ testWithFactory("basic conversion with struct and map") { factory =>
val fieldTypes: Array[DataType] = Array(
new StructType().add("map", MapType(IntegerType, IntegerType)),
MapType(IntegerType, new StructType().add("l", LongType))
)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val row = new GenericInternalRow(fieldTypes.length)
row.update(0, InternalRow(createMap(1)(2)))
@@ -478,12 +486,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
}
- test("basic conversion with array and map") {
+ testWithFactory("basic conversion with array and map") { factory =>
val fieldTypes: Array[DataType] = Array(
ArrayType(MapType(IntegerType, IntegerType)),
MapType(IntegerType, ArrayType(IntegerType))
)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val row = new GenericInternalRow(fieldTypes.length)
row.update(0, createArray(createMap(1)(2)))
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org