Posted to commits@spark.apache.org by we...@apache.org on 2017/07/27 11:20:00 UTC
spark git commit: [SPARK-21440][SQL][PYSPARK] Refactor ArrowConverters and add ArrayType and StructType support.
Repository: spark
Updated Branches:
refs/heads/master ebbe589d1 -> 2ff35a057
[SPARK-21440][SQL][PYSPARK] Refactor ArrowConverters and add ArrayType and StructType support.
## What changes were proposed in this pull request?
This is a refactoring of `ArrowConverters` and related classes.
1. Refactor `ColumnWriter` as `ArrowWriter`.
2. Add `ArrayType` and `StructType` support.
3. Refactor `ArrowConverters` to skip intermediate `ArrowRecordBatch` creation.
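As a quick illustration of the new API, here is a minimal sketch (not part of the patch, with a made-up schema and values) of how `ArrowWriter` is used; it mirrors the usage in the new `ArrowWriterSuite` below:

```scala
// Minimal sketch of the new ArrowWriter API (mirrors ArrowWriterSuite).
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.execution.arrow.ArrowWriter
import org.apache.spark.sql.execution.vectorized.ArrowColumnVector
import org.apache.spark.sql.types._

// Hypothetical schema: a single nullable array column, now supported by this patch.
val schema = new StructType().add("arr", ArrayType(IntegerType, containsNull = true))
val writer = ArrowWriter.create(schema)  // allocates Arrow vectors for the schema

writer.write(InternalRow(ArrayData.toArrayData(Array(1, 2, 3))))
writer.write(InternalRow(null))          // nulls are handled per field
writer.finish()                          // sets the row count on the underlying root

// Values can be read back through ArrowColumnVector.
val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
assert(reader.getArray(0).getInt(2) == 3)
assert(reader.isNullAt(1))

writer.root.close()                      // release Arrow buffers
```

In `ArrowConverters.toPayloadIterator`, the same writer is reused via `arrowWriter.reset()` between batches and the batch is serialized straight from the `VectorSchemaRoot`, which is what removes the intermediate `ArrowRecordBatch` creation.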
## How was this patch tested?
Added new tests and verified with the existing tests.
Author: Takuya UESHIN <ue...@databricks.com>
Closes #18655 from ueshin/issues/SPARK-21440.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2ff35a05
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2ff35a05
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2ff35a05
Branch: refs/heads/master
Commit: 2ff35a057efd36bd5c8a545a1ec3bc341432a904
Parents: ebbe589
Author: Takuya UESHIN <ue...@databricks.com>
Authored: Thu Jul 27 19:19:51 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Jul 27 19:19:51 2017 +0800
----------------------------------------------------------------------
python/pyspark/sql/tests.py | 4 +-
.../scala/org/apache/spark/sql/Dataset.scala | 4 +-
.../sql/execution/arrow/ArrowConverters.scala | 351 ++-------------
.../spark/sql/execution/arrow/ArrowWriter.scala | 323 ++++++++++++++
.../execution/arrow/ArrowConvertersSuite.scala | 447 ++++++++++++++++++-
.../sql/execution/arrow/ArrowWriterSuite.scala | 260 +++++++++++
6 files changed, 1074 insertions(+), 315 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/2ff35a05/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1c1a0ca..54756ed 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3018,8 +3018,8 @@ class ArrowTests(ReusedPySparkTestCase):
self.assertTrue(df_without.equals(df_with_arrow), msg=msg)
def test_unsupported_datatype(self):
- schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)])
- df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema)
+ schema = StructType([StructField("dt", DateType(), True)])
+ df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema)
with QuietTest(self.sc):
self.assertRaises(Exception, lambda: df.toPandas())
http://git-wip-us.apache.org/repos/asf/spark/blob/2ff35a05/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 71ab0dd..9007367 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -27,6 +27,7 @@ import scala.util.control.NonFatal
import org.apache.commons.lang3.StringUtils
+import org.apache.spark.TaskContext
import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.function._
@@ -3090,7 +3091,8 @@ class Dataset[T] private[sql](
val schemaCaptured = this.schema
val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
queryExecution.toRdd.mapPartitionsInternal { iter =>
- ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch)
+ val context = TaskContext.get()
+ ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch, context)
}
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2ff35a05/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index c913efe..240f38f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -20,18 +20,13 @@ package org.apache.spark.sql.execution.arrow
import java.io.ByteArrayOutputStream
import java.nio.channels.Channels
-import scala.collection.JavaConverters._
-
-import io.netty.buffer.ArrowBuf
-import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
+import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector._
-import org.apache.arrow.vector.BaseValueVector.BaseMutator
import org.apache.arrow.vector.file._
-import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
-import org.apache.arrow.vector.types.FloatingPointPrecision
-import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
+import org.apache.arrow.vector.schema.ArrowRecordBatch
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
+import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -55,19 +50,6 @@ private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Se
def asPythonSerializable: Array[Byte] = payload
}
-private[sql] object ArrowPayload {
-
- /**
- * Create an ArrowPayload from an ArrowRecordBatch and Spark schema.
- */
- def apply(
- batch: ArrowRecordBatch,
- schema: StructType,
- allocator: BufferAllocator): ArrowPayload = {
- new ArrowPayload(ArrowConverters.batchToByteArray(batch, schema, allocator))
- }
-}
-
private[sql] object ArrowConverters {
/**
@@ -77,95 +59,55 @@ private[sql] object ArrowConverters {
private[sql] def toPayloadIterator(
rowIter: Iterator[InternalRow],
schema: StructType,
- maxRecordsPerBatch: Int): Iterator[ArrowPayload] = {
- new Iterator[ArrowPayload] {
- private val _allocator = new RootAllocator(Long.MaxValue)
- private var _nextPayload = if (rowIter.nonEmpty) convert() else null
+ maxRecordsPerBatch: Int,
+ context: TaskContext): Iterator[ArrowPayload] = {
- override def hasNext: Boolean = _nextPayload != null
-
- override def next(): ArrowPayload = {
- val obj = _nextPayload
- if (hasNext) {
- if (rowIter.hasNext) {
- _nextPayload = convert()
- } else {
- _allocator.close()
- _nextPayload = null
- }
- }
- obj
- }
-
- private def convert(): ArrowPayload = {
- val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator, maxRecordsPerBatch)
- ArrowPayload(batch, schema, _allocator)
- }
- }
- }
+ val arrowSchema = ArrowUtils.toArrowSchema(schema)
+ val allocator =
+ ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue)
- /**
- * Iterate over InternalRows and write to an ArrowRecordBatch, stopping when rowIter is consumed
- * or the number of records in the batch equals maxRecordsInBatch. If maxRecordsPerBatch is 0,
- * then rowIter will be fully consumed.
- */
- private def internalRowIterToArrowBatch(
- rowIter: Iterator[InternalRow],
- schema: StructType,
- allocator: BufferAllocator,
- maxRecordsPerBatch: Int = 0): ArrowRecordBatch = {
+ val root = VectorSchemaRoot.create(arrowSchema, allocator)
+ val arrowWriter = ArrowWriter.create(root)
- val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) =>
- ColumnWriter(field.dataType, ordinal, allocator).init()
- }
+ var closed = false
- val writerLength = columnWriters.length
- var recordsInBatch = 0
- while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || recordsInBatch < maxRecordsPerBatch)) {
- val row = rowIter.next()
- var i = 0
- while (i < writerLength) {
- columnWriters(i).write(row)
- i += 1
+ context.addTaskCompletionListener { _ =>
+ if (!closed) {
+ root.close()
+ allocator.close()
}
- recordsInBatch += 1
}
- val (fieldNodes, bufferArrays) = columnWriters.map(_.finish()).unzip
- val buffers = bufferArrays.flatten
-
- val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0
- val recordBatch = new ArrowRecordBatch(rowLength,
- fieldNodes.toList.asJava, buffers.toList.asJava)
+ new Iterator[ArrowPayload] {
- buffers.foreach(_.release())
- recordBatch
- }
+ override def hasNext: Boolean = rowIter.hasNext || {
+ root.close()
+ allocator.close()
+ closed = true
+ false
+ }
- /**
- * Convert an ArrowRecordBatch to a byte array and close batch to release resources. Once closed,
- * the batch can no longer be used.
- */
- private[arrow] def batchToByteArray(
- batch: ArrowRecordBatch,
- schema: StructType,
- allocator: BufferAllocator): Array[Byte] = {
- val arrowSchema = ArrowUtils.toArrowSchema(schema)
- val root = VectorSchemaRoot.create(arrowSchema, allocator)
- val out = new ByteArrayOutputStream()
- val writer = new ArrowFileWriter(root, null, Channels.newChannel(out))
+ override def next(): ArrowPayload = {
+ val out = new ByteArrayOutputStream()
+ val writer = new ArrowFileWriter(root, null, Channels.newChannel(out))
+
+ Utils.tryWithSafeFinally {
+ var rowCount = 0
+ while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+ val row = rowIter.next()
+ arrowWriter.write(row)
+ rowCount += 1
+ }
+ arrowWriter.finish()
+ writer.writeBatch()
+ } {
+ arrowWriter.reset()
+ writer.close()
+ }
- // Write a batch to byte stream, ensure the batch, allocator and writer are closed
- Utils.tryWithSafeFinally {
- val loader = new VectorLoader(root)
- loader.load(batch)
- writer.writeBatch() // writeBatch can throw IOException
- } {
- batch.close()
- root.close()
- writer.close()
+ new ArrowPayload(out.toByteArray)
+ }
}
- out.toByteArray
}
/**
@@ -188,214 +130,3 @@ private[sql] object ArrowConverters {
}
}
}
-
-/**
- * Interface for writing InternalRows to Arrow Buffers.
- */
-private[arrow] trait ColumnWriter {
- def init(): this.type
- def write(row: InternalRow): Unit
-
- /**
- * Clear the column writer and return the ArrowFieldNode and ArrowBuf.
- * This should be called only once after all the data is written.
- */
- def finish(): (ArrowFieldNode, Array[ArrowBuf])
-}
-
-/**
- * Base class for flat arrow column writer, i.e., column without children.
- */
-private[arrow] abstract class PrimitiveColumnWriter(val ordinal: Int)
- extends ColumnWriter {
-
- def getFieldType(dtype: ArrowType): FieldType = FieldType.nullable(dtype)
-
- def valueVector: BaseDataValueVector
- def valueMutator: BaseMutator
-
- def setNull(): Unit
- def setValue(row: InternalRow): Unit
-
- protected var count = 0
- protected var nullCount = 0
-
- override def init(): this.type = {
- valueVector.allocateNew()
- this
- }
-
- override def write(row: InternalRow): Unit = {
- if (row.isNullAt(ordinal)) {
- setNull()
- nullCount += 1
- } else {
- setValue(row)
- }
- count += 1
- }
-
- override def finish(): (ArrowFieldNode, Array[ArrowBuf]) = {
- valueMutator.setValueCount(count)
- val fieldNode = new ArrowFieldNode(count, nullCount)
- val valueBuffers = valueVector.getBuffers(true)
- (fieldNode, valueBuffers)
- }
-}
-
-private[arrow] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
- extends PrimitiveColumnWriter(ordinal) {
- override val valueVector: NullableBitVector
- = new NullableBitVector("BooleanValue", getFieldType(dtype), allocator)
- override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator
-
- override def setNull(): Unit = valueMutator.setNull(count)
- override def setValue(row: InternalRow): Unit
- = valueMutator.setSafe(count, if (row.getBoolean(ordinal)) 1 else 0 )
-}
-
-private[arrow] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
- extends PrimitiveColumnWriter(ordinal) {
- override val valueVector: NullableSmallIntVector
- = new NullableSmallIntVector("ShortValue", getFieldType(dtype: ArrowType), allocator)
- override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator
-
- override def setNull(): Unit = valueMutator.setNull(count)
- override def setValue(row: InternalRow): Unit
- = valueMutator.setSafe(count, row.getShort(ordinal))
-}
-
-private[arrow] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
- extends PrimitiveColumnWriter(ordinal) {
- override val valueVector: NullableIntVector
- = new NullableIntVector("IntValue", getFieldType(dtype), allocator)
- override val valueMutator: NullableIntVector#Mutator = valueVector.getMutator
-
- override def setNull(): Unit = valueMutator.setNull(count)
- override def setValue(row: InternalRow): Unit
- = valueMutator.setSafe(count, row.getInt(ordinal))
-}
-
-private[arrow] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
- extends PrimitiveColumnWriter(ordinal) {
- override val valueVector: NullableBigIntVector
- = new NullableBigIntVector("LongValue", getFieldType(dtype), allocator)
- override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator
-
- override def setNull(): Unit = valueMutator.setNull(count)
- override def setValue(row: InternalRow): Unit
- = valueMutator.setSafe(count, row.getLong(ordinal))
-}
-
-private[arrow] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
- extends PrimitiveColumnWriter(ordinal) {
- override val valueVector: NullableFloat4Vector
- = new NullableFloat4Vector("FloatValue", getFieldType(dtype), allocator)
- override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator
-
- override def setNull(): Unit = valueMutator.setNull(count)
- override def setValue(row: InternalRow): Unit
- = valueMutator.setSafe(count, row.getFloat(ordinal))
-}
-
-private[arrow] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
- extends PrimitiveColumnWriter(ordinal) {
- override val valueVector: NullableFloat8Vector
- = new NullableFloat8Vector("DoubleValue", getFieldType(dtype), allocator)
- override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator
-
- override def setNull(): Unit = valueMutator.setNull(count)
- override def setValue(row: InternalRow): Unit
- = valueMutator.setSafe(count, row.getDouble(ordinal))
-}
-
-private[arrow] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
- extends PrimitiveColumnWriter(ordinal) {
- override val valueVector: NullableUInt1Vector
- = new NullableUInt1Vector("ByteValue", getFieldType(dtype), allocator)
- override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator
-
- override def setNull(): Unit = valueMutator.setNull(count)
- override def setValue(row: InternalRow): Unit
- = valueMutator.setSafe(count, row.getByte(ordinal))
-}
-
-private[arrow] class UTF8StringColumnWriter(
- dtype: ArrowType,
- ordinal: Int,
- allocator: BufferAllocator)
- extends PrimitiveColumnWriter(ordinal) {
- override val valueVector: NullableVarCharVector
- = new NullableVarCharVector("UTF8StringValue", getFieldType(dtype), allocator)
- override val valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator
-
- override def setNull(): Unit = valueMutator.setNull(count)
- override def setValue(row: InternalRow): Unit = {
- val str = row.getUTF8String(ordinal)
- valueMutator.setSafe(count, str.getByteBuffer, 0, str.numBytes)
- }
-}
-
-private[arrow] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
- extends PrimitiveColumnWriter(ordinal) {
- override val valueVector: NullableVarBinaryVector
- = new NullableVarBinaryVector("BinaryValue", getFieldType(dtype), allocator)
- override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator
-
- override def setNull(): Unit = valueMutator.setNull(count)
- override def setValue(row: InternalRow): Unit = {
- val bytes = row.getBinary(ordinal)
- valueMutator.setSafe(count, bytes, 0, bytes.length)
- }
-}
-
-private[arrow] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
- extends PrimitiveColumnWriter(ordinal) {
- override val valueVector: NullableDateDayVector
- = new NullableDateDayVector("DateValue", getFieldType(dtype), allocator)
- override val valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator
-
- override def setNull(): Unit = valueMutator.setNull(count)
- override def setValue(row: InternalRow): Unit = {
- valueMutator.setSafe(count, row.getInt(ordinal))
- }
-}
-
-private[arrow] class TimeStampColumnWriter(
- dtype: ArrowType,
- ordinal: Int,
- allocator: BufferAllocator)
- extends PrimitiveColumnWriter(ordinal) {
- override val valueVector: NullableTimeStampMicroVector
- = new NullableTimeStampMicroVector("TimeStampValue", getFieldType(dtype), allocator)
- override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator
-
- override def setNull(): Unit = valueMutator.setNull(count)
- override def setValue(row: InternalRow): Unit = {
- valueMutator.setSafe(count, row.getLong(ordinal))
- }
-}
-
-private[arrow] object ColumnWriter {
-
- /**
- * Create an Arrow ColumnWriter given the type and ordinal of row.
- */
- def apply(dataType: DataType, ordinal: Int, allocator: BufferAllocator): ColumnWriter = {
- val dtype = ArrowUtils.toArrowType(dataType)
- dataType match {
- case BooleanType => new BooleanColumnWriter(dtype, ordinal, allocator)
- case ShortType => new ShortColumnWriter(dtype, ordinal, allocator)
- case IntegerType => new IntegerColumnWriter(dtype, ordinal, allocator)
- case LongType => new LongColumnWriter(dtype, ordinal, allocator)
- case FloatType => new FloatColumnWriter(dtype, ordinal, allocator)
- case DoubleType => new DoubleColumnWriter(dtype, ordinal, allocator)
- case ByteType => new ByteColumnWriter(dtype, ordinal, allocator)
- case StringType => new UTF8StringColumnWriter(dtype, ordinal, allocator)
- case BinaryType => new BinaryColumnWriter(dtype, ordinal, allocator)
- case DateType => new DateColumnWriter(dtype, ordinal, allocator)
- case TimestampType => new TimeStampColumnWriter(dtype, ordinal, allocator)
- case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType")
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/2ff35a05/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
new file mode 100644
index 0000000..11ba04d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -0,0 +1,323 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.arrow
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.complex._
+import org.apache.arrow.vector.util.DecimalUtility
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.types._
+
+object ArrowWriter {
+
+ def create(schema: StructType): ArrowWriter = {
+ val arrowSchema = ArrowUtils.toArrowSchema(schema)
+ val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator)
+ create(root)
+ }
+
+ def create(root: VectorSchemaRoot): ArrowWriter = {
+ val children = root.getFieldVectors().asScala.map { vector =>
+ vector.allocateNew()
+ createFieldWriter(vector)
+ }
+ new ArrowWriter(root, children.toArray)
+ }
+
+ private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = {
+ val field = vector.getField()
+ (ArrowUtils.fromArrowField(field), vector) match {
+ case (BooleanType, vector: NullableBitVector) => new BooleanWriter(vector)
+ case (ByteType, vector: NullableTinyIntVector) => new ByteWriter(vector)
+ case (ShortType, vector: NullableSmallIntVector) => new ShortWriter(vector)
+ case (IntegerType, vector: NullableIntVector) => new IntegerWriter(vector)
+ case (LongType, vector: NullableBigIntVector) => new LongWriter(vector)
+ case (FloatType, vector: NullableFloat4Vector) => new FloatWriter(vector)
+ case (DoubleType, vector: NullableFloat8Vector) => new DoubleWriter(vector)
+ case (StringType, vector: NullableVarCharVector) => new StringWriter(vector)
+ case (BinaryType, vector: NullableVarBinaryVector) => new BinaryWriter(vector)
+ case (ArrayType(_, _), vector: ListVector) =>
+ val elementVector = createFieldWriter(vector.getDataVector())
+ new ArrayWriter(vector, elementVector)
+ case (StructType(_), vector: NullableMapVector) =>
+ val children = (0 until vector.size()).map { ordinal =>
+ createFieldWriter(vector.getChildByOrdinal(ordinal))
+ }
+ new StructWriter(vector, children.toArray)
+ case (dt, _) =>
+ throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}")
+ }
+ }
+}
+
+class ArrowWriter(
+ val root: VectorSchemaRoot,
+ fields: Array[ArrowFieldWriter]) {
+
+ def schema: StructType = StructType(fields.map { f =>
+ StructField(f.name, f.dataType, f.nullable)
+ })
+
+ private var count: Int = 0
+
+ def write(row: InternalRow): Unit = {
+ var i = 0
+ while (i < fields.size) {
+ fields(i).write(row, i)
+ i += 1
+ }
+ count += 1
+ }
+
+ def finish(): Unit = {
+ root.setRowCount(count)
+ fields.foreach(_.finish())
+ }
+
+ def reset(): Unit = {
+ root.setRowCount(0)
+ count = 0
+ fields.foreach(_.reset())
+ }
+}
+
+private[arrow] abstract class ArrowFieldWriter {
+
+ def valueVector: ValueVector
+ def valueMutator: ValueVector.Mutator
+
+ def name: String = valueVector.getField().getName()
+ def dataType: DataType = ArrowUtils.fromArrowField(valueVector.getField())
+ def nullable: Boolean = valueVector.getField().isNullable()
+
+ def setNull(): Unit
+ def setValue(input: SpecializedGetters, ordinal: Int): Unit
+
+ private[arrow] var count: Int = 0
+
+ def write(input: SpecializedGetters, ordinal: Int): Unit = {
+ if (input.isNullAt(ordinal)) {
+ setNull()
+ } else {
+ setValue(input, ordinal)
+ }
+ count += 1
+ }
+
+ def finish(): Unit = {
+ valueMutator.setValueCount(count)
+ }
+
+ def reset(): Unit = {
+ valueMutator.reset()
+ count = 0
+ }
+}
+
+private[arrow] class BooleanWriter(val valueVector: NullableBitVector) extends ArrowFieldWriter {
+
+ override def valueMutator: NullableBitVector#Mutator = valueVector.getMutator()
+
+ override def setNull(): Unit = {
+ valueMutator.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueMutator.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0)
+ }
+}
+
+private[arrow] class ByteWriter(val valueVector: NullableTinyIntVector) extends ArrowFieldWriter {
+
+ override def valueMutator: NullableTinyIntVector#Mutator = valueVector.getMutator()
+
+ override def setNull(): Unit = {
+ valueMutator.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueMutator.setSafe(count, input.getByte(ordinal))
+ }
+}
+
+private[arrow] class ShortWriter(val valueVector: NullableSmallIntVector) extends ArrowFieldWriter {
+
+ override def valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator()
+
+ override def setNull(): Unit = {
+ valueMutator.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueMutator.setSafe(count, input.getShort(ordinal))
+ }
+}
+
+private[arrow] class IntegerWriter(val valueVector: NullableIntVector) extends ArrowFieldWriter {
+
+ override def valueMutator: NullableIntVector#Mutator = valueVector.getMutator()
+
+ override def setNull(): Unit = {
+ valueMutator.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueMutator.setSafe(count, input.getInt(ordinal))
+ }
+}
+
+private[arrow] class LongWriter(val valueVector: NullableBigIntVector) extends ArrowFieldWriter {
+
+ override def valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator()
+
+ override def setNull(): Unit = {
+ valueMutator.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueMutator.setSafe(count, input.getLong(ordinal))
+ }
+}
+
+private[arrow] class FloatWriter(val valueVector: NullableFloat4Vector) extends ArrowFieldWriter {
+
+ override def valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator()
+
+ override def setNull(): Unit = {
+ valueMutator.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueMutator.setSafe(count, input.getFloat(ordinal))
+ }
+}
+
+private[arrow] class DoubleWriter(val valueVector: NullableFloat8Vector) extends ArrowFieldWriter {
+
+ override def valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator()
+
+ override def setNull(): Unit = {
+ valueMutator.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueMutator.setSafe(count, input.getDouble(ordinal))
+ }
+}
+
+private[arrow] class StringWriter(val valueVector: NullableVarCharVector) extends ArrowFieldWriter {
+
+ override def valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator()
+
+ override def setNull(): Unit = {
+ valueMutator.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ val utf8 = input.getUTF8String(ordinal)
+ // todo: for off-heap UTF8String, how to pass in to arrow without copy?
+ valueMutator.setSafe(count, utf8.getByteBuffer, 0, utf8.numBytes())
+ }
+}
+
+private[arrow] class BinaryWriter(
+ val valueVector: NullableVarBinaryVector) extends ArrowFieldWriter {
+
+ override def valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator()
+
+ override def setNull(): Unit = {
+ valueMutator.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ val bytes = input.getBinary(ordinal)
+ valueMutator.setSafe(count, bytes, 0, bytes.length)
+ }
+}
+
+private[arrow] class ArrayWriter(
+ val valueVector: ListVector,
+ val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter {
+
+ override def valueMutator: ListVector#Mutator = valueVector.getMutator()
+
+ override def setNull(): Unit = {
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ val array = input.getArray(ordinal)
+ var i = 0
+ valueMutator.startNewValue(count)
+ while (i < array.numElements()) {
+ elementWriter.write(array, i)
+ i += 1
+ }
+ valueMutator.endValue(count, array.numElements())
+ }
+
+ override def finish(): Unit = {
+ super.finish()
+ elementWriter.finish()
+ }
+
+ override def reset(): Unit = {
+ super.reset()
+ elementWriter.reset()
+ }
+}
+
+private[arrow] class StructWriter(
+ val valueVector: NullableMapVector,
+ children: Array[ArrowFieldWriter]) extends ArrowFieldWriter {
+
+ override def valueMutator: NullableMapVector#Mutator = valueVector.getMutator()
+
+ override def setNull(): Unit = {
+ var i = 0
+ while (i < children.length) {
+ children(i).setNull()
+ children(i).count += 1
+ i += 1
+ }
+ valueMutator.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ val struct = input.getStruct(ordinal, children.length)
+ var i = 0
+ while (i < struct.numFields) {
+ children(i).write(struct, i)
+ i += 1
+ }
+ valueMutator.setIndexDefined(count)
+ }
+
+ override def finish(): Unit = {
+ super.finish()
+ children.foreach(_.finish())
+ }
+
+ override def reset(): Unit = {
+ super.reset()
+ children.foreach(_.reset())
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/2ff35a05/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
index 55b4655..4893b52 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
@@ -32,7 +32,7 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkException
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
+import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType}
import org.apache.spark.util.Utils
@@ -857,6 +857,449 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
collectAndValidate(df, json, "nanData-floating_point.json")
}
+ test("array type conversion") {
+ val json =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "a_arr",
+ | "nullable" : true,
+ | "type" : {
+ | "name" : "list"
+ | },
+ | "children" : [ {
+ | "name" : "element",
+ | "nullable" : false,
+ | "type" : {
+ | "name" : "int",
+ | "bitWidth" : 32,
+ | "isSigned" : true
+ | },
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | } ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "OFFSET",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | }, {
+ | "name" : "b_arr",
+ | "nullable" : true,
+ | "type" : {
+ | "name" : "list"
+ | },
+ | "children" : [ {
+ | "name" : "element",
+ | "nullable" : false,
+ | "type" : {
+ | "name" : "int",
+ | "bitWidth" : 32,
+ | "isSigned" : true
+ | },
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | } ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "OFFSET",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | }, {
+ | "name" : "c_arr",
+ | "nullable" : true,
+ | "type" : {
+ | "name" : "list"
+ | },
+ | "children" : [ {
+ | "name" : "element",
+ | "nullable" : true,
+ | "type" : {
+ | "name" : "int",
+ | "bitWidth" : 32,
+ | "isSigned" : true
+ | },
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | } ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "OFFSET",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | }, {
+ | "name" : "d_arr",
+ | "nullable" : true,
+ | "type" : {
+ | "name" : "list"
+ | },
+ | "children" : [ {
+ | "name" : "element",
+ | "nullable" : true,
+ | "type" : {
+ | "name" : "list"
+ | },
+ | "children" : [ {
+ | "name" : "element",
+ | "nullable" : false,
+ | "type" : {
+ | "name" : "int",
+ | "bitWidth" : 32,
+ | "isSigned" : true
+ | },
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | } ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "OFFSET",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | } ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "OFFSET",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 4,
+ | "columns" : [ {
+ | "name" : "a_arr",
+ | "count" : 4,
+ | "VALIDITY" : [ 1, 1, 1, 1 ],
+ | "OFFSET" : [ 0, 2, 4, 4, 5 ],
+ | "children" : [ {
+ | "name" : "element",
+ | "count" : 5,
+ | "VALIDITY" : [ 1, 1, 1, 1, 1 ],
+ | "DATA" : [ 1, 2, 3, 4, 5 ]
+ | } ]
+ | }, {
+ | "name" : "b_arr",
+ | "count" : 4,
+ | "VALIDITY" : [ 1, 0, 1, 0 ],
+ | "OFFSET" : [ 0, 2, 2, 2, 2 ],
+ | "children" : [ {
+ | "name" : "element",
+ | "count" : 2,
+ | "VALIDITY" : [ 1, 1 ],
+ | "DATA" : [ 1, 2 ]
+ | } ]
+ | }, {
+ | "name" : "c_arr",
+ | "count" : 4,
+ | "VALIDITY" : [ 1, 1, 1, 1 ],
+ | "OFFSET" : [ 0, 2, 4, 4, 5 ],
+ | "children" : [ {
+ | "name" : "element",
+ | "count" : 5,
+ | "VALIDITY" : [ 1, 1, 1, 0, 1 ],
+ | "DATA" : [ 1, 2, 3, 0, 5 ]
+ | } ]
+ | }, {
+ | "name" : "d_arr",
+ | "count" : 4,
+ | "VALIDITY" : [ 1, 1, 1, 1 ],
+ | "OFFSET" : [ 0, 1, 3, 3, 4 ],
+ | "children" : [ {
+ | "name" : "element",
+ | "count" : 4,
+ | "VALIDITY" : [ 1, 1, 1, 1 ],
+ | "OFFSET" : [ 0, 2, 3, 3, 4 ],
+ | "children" : [ {
+ | "name" : "element",
+ | "count" : 4,
+ | "VALIDITY" : [ 1, 1, 1, 1 ],
+ | "DATA" : [ 1, 2, 3, 5 ]
+ | } ]
+ | } ]
+ | } ]
+ | } ]
+ |}
+ """.stripMargin
+
+ val aArr = Seq(Seq(1, 2), Seq(3, 4), Seq(), Seq(5))
+ val bArr = Seq(Some(Seq(1, 2)), None, Some(Seq()), None)
+ val cArr = Seq(Seq(Some(1), Some(2)), Seq(Some(3), None), Seq(), Seq(Some(5)))
+ val dArr = Seq(Seq(Seq(1, 2)), Seq(Seq(3), Seq()), Seq(), Seq(Seq(5)))
+
+ val df = aArr.zip(bArr).zip(cArr).zip(dArr).map {
+ case (((a, b), c), d) => (a, b, c, d)
+ }.toDF("a_arr", "b_arr", "c_arr", "d_arr")
+
+ collectAndValidate(df, json, "arrayData.json")
+ }
+
+ test("struct type conversion") {
+ val json =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "a_struct",
+ | "nullable" : false,
+ | "type" : {
+ | "name" : "struct"
+ | },
+ | "children" : [ {
+ | "name" : "i",
+ | "nullable" : false,
+ | "type" : {
+ | "name" : "int",
+ | "bitWidth" : 32,
+ | "isSigned" : true
+ | },
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | } ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | } ]
+ | }
+ | }, {
+ | "name" : "b_struct",
+ | "nullable" : true,
+ | "type" : {
+ | "name" : "struct"
+ | },
+ | "children" : [ {
+ | "name" : "i",
+ | "nullable" : false,
+ | "type" : {
+ | "name" : "int",
+ | "bitWidth" : 32,
+ | "isSigned" : true
+ | },
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | } ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | } ]
+ | }
+ | }, {
+ | "name" : "c_struct",
+ | "nullable" : false,
+ | "type" : {
+ | "name" : "struct"
+ | },
+ | "children" : [ {
+ | "name" : "i",
+ | "nullable" : true,
+ | "type" : {
+ | "name" : "int",
+ | "bitWidth" : 32,
+ | "isSigned" : true
+ | },
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | } ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | } ]
+ | }
+ | }, {
+ | "name" : "d_struct",
+ | "nullable" : true,
+ | "type" : {
+ | "name" : "struct"
+ | },
+ | "children" : [ {
+ | "name" : "nested",
+ | "nullable" : true,
+ | "type" : {
+ | "name" : "struct"
+ | },
+ | "children" : [ {
+ | "name" : "i",
+ | "nullable" : true,
+ | "type" : {
+ | "name" : "int",
+ | "bitWidth" : 32,
+ | "isSigned" : true
+ | },
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | } ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | } ]
+ | }
+ | } ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 3,
+ | "columns" : [ {
+ | "name" : "a_struct",
+ | "count" : 3,
+ | "VALIDITY" : [ 1, 1, 1 ],
+ | "children" : [ {
+ | "name" : "i",
+ | "count" : 3,
+ | "VALIDITY" : [ 1, 1, 1 ],
+ | "DATA" : [ 1, 2, 3 ]
+ | } ]
+ | }, {
+ | "name" : "b_struct",
+ | "count" : 3,
+ | "VALIDITY" : [ 1, 0, 1 ],
+ | "children" : [ {
+ | "name" : "i",
+ | "count" : 3,
+ | "VALIDITY" : [ 1, 0, 1 ],
+ | "DATA" : [ 1, 2, 3 ]
+ | } ]
+ | }, {
+ | "name" : "c_struct",
+ | "count" : 3,
+ | "VALIDITY" : [ 1, 1, 1 ],
+ | "children" : [ {
+ | "name" : "i",
+ | "count" : 3,
+ | "VALIDITY" : [ 1, 0, 1 ],
+ | "DATA" : [ 1, 2, 3 ]
+ | } ]
+ | }, {
+ | "name" : "d_struct",
+ | "count" : 3,
+ | "VALIDITY" : [ 1, 0, 1 ],
+ | "children" : [ {
+ | "name" : "nested",
+ | "count" : 3,
+ | "VALIDITY" : [ 1, 0, 0 ],
+ | "children" : [ {
+ | "name" : "i",
+ | "count" : 3,
+ | "VALIDITY" : [ 1, 0, 0 ],
+ | "DATA" : [ 1, 2, 0 ]
+ | } ]
+ | } ]
+ | } ]
+ | } ]
+ |}
+ """.stripMargin
+
+ val aStruct = Seq(Row(1), Row(2), Row(3))
+ val bStruct = Seq(Row(1), null, Row(3))
+ val cStruct = Seq(Row(1), Row(null), Row(3))
+ val dStruct = Seq(Row(Row(1)), null, Row(null))
+ val data = aStruct.zip(bStruct).zip(cStruct).zip(dStruct).map {
+ case (((a, b), c), d) => Row(a, b, c, d)
+ }
+
+ val rdd = sparkContext.parallelize(data)
+ val schema = new StructType()
+ .add("a_struct", new StructType().add("i", IntegerType, nullable = false), nullable = false)
+ .add("b_struct", new StructType().add("i", IntegerType, nullable = false), nullable = true)
+ .add("c_struct", new StructType().add("i", IntegerType, nullable = true), nullable = false)
+ .add("d_struct", new StructType().add("nested", new StructType().add("i", IntegerType)))
+ val df = spark.createDataFrame(rdd, schema)
+
+ collectAndValidate(df, json, "structData.json")
+ }
+
test("partitioned DataFrame") {
val json1 =
s"""
@@ -1015,6 +1458,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch)
val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i")
val arrowPayloads = df.toArrowPayload.collect()
+ assert(arrowPayloads.length >= 4)
val allocator = new RootAllocator(Long.MaxValue)
val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator))
var recordCount = 0
@@ -1039,7 +1483,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
}
runUnsupported { decimalData.toArrowPayload.collect() }
- runUnsupported { arrayData.toDF().toArrowPayload.collect() }
runUnsupported { mapData.toDF().toArrowPayload.collect() }
runUnsupported { complexData.toArrowPayload.collect() }
http://git-wip-us.apache.org/repos/asf/spark/blob/2ff35a05/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
new file mode 100644
index 0000000..e9a6293
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.arrow
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.execution.vectorized.ArrowColumnVector
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+class ArrowWriterSuite extends SparkFunSuite {
+
+ test("simple") {
+ def check(dt: DataType, data: Seq[Any]): Unit = {
+ val schema = new StructType().add("value", dt, nullable = true)
+ val writer = ArrowWriter.create(schema)
+ assert(writer.schema === schema)
+
+ data.foreach { datum =>
+ writer.write(InternalRow(datum))
+ }
+ writer.finish()
+
+ val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
+ data.zipWithIndex.foreach {
+ case (null, rowId) => assert(reader.isNullAt(rowId))
+ case (datum, rowId) =>
+ val value = dt match {
+ case BooleanType => reader.getBoolean(rowId)
+ case ByteType => reader.getByte(rowId)
+ case ShortType => reader.getShort(rowId)
+ case IntegerType => reader.getInt(rowId)
+ case LongType => reader.getLong(rowId)
+ case FloatType => reader.getFloat(rowId)
+ case DoubleType => reader.getDouble(rowId)
+ case StringType => reader.getUTF8String(rowId)
+ case BinaryType => reader.getBinary(rowId)
+ }
+ assert(value === datum)
+ }
+
+ writer.root.close()
+ }
+ check(BooleanType, Seq(true, null, false))
+ check(ByteType, Seq(1.toByte, 2.toByte, null, 4.toByte))
+ check(ShortType, Seq(1.toShort, 2.toShort, null, 4.toShort))
+ check(IntegerType, Seq(1, 2, null, 4))
+ check(LongType, Seq(1L, 2L, null, 4L))
+ check(FloatType, Seq(1.0f, 2.0f, null, 4.0f))
+ check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d))
+ check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString))
+ check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes()))
+ }
+
+ test("get multiple") {
+ def check(dt: DataType, data: Seq[Any]): Unit = {
+ val schema = new StructType().add("value", dt, nullable = false)
+ val writer = ArrowWriter.create(schema)
+ assert(writer.schema === schema)
+
+ data.foreach { datum =>
+ writer.write(InternalRow(datum))
+ }
+ writer.finish()
+
+ val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
+ val values = dt match {
+ case BooleanType => reader.getBooleans(0, data.size)
+ case ByteType => reader.getBytes(0, data.size)
+ case ShortType => reader.getShorts(0, data.size)
+ case IntegerType => reader.getInts(0, data.size)
+ case LongType => reader.getLongs(0, data.size)
+ case FloatType => reader.getFloats(0, data.size)
+ case DoubleType => reader.getDoubles(0, data.size)
+ }
+ assert(values === data)
+
+ writer.root.close()
+ }
+ check(BooleanType, Seq(true, false))
+ check(ByteType, (0 until 10).map(_.toByte))
+ check(ShortType, (0 until 10).map(_.toShort))
+ check(IntegerType, (0 until 10))
+ check(LongType, (0 until 10).map(_.toLong))
+ check(FloatType, (0 until 10).map(_.toFloat))
+ check(DoubleType, (0 until 10).map(_.toDouble))
+ }
+
+ test("array") {
+ val schema = new StructType()
+ .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true)
+ val writer = ArrowWriter.create(schema)
+ assert(writer.schema === schema)
+
+ writer.write(InternalRow(ArrayData.toArrayData(Array(1, 2, 3))))
+ writer.write(InternalRow(ArrayData.toArrayData(Array(4, 5))))
+ writer.write(InternalRow(null))
+ writer.write(InternalRow(ArrayData.toArrayData(Array.empty[Int])))
+ writer.write(InternalRow(ArrayData.toArrayData(Array(6, null, 8))))
+ writer.finish()
+
+ val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
+
+ val array0 = reader.getArray(0)
+ assert(array0.numElements() === 3)
+ assert(array0.getInt(0) === 1)
+ assert(array0.getInt(1) === 2)
+ assert(array0.getInt(2) === 3)
+
+ val array1 = reader.getArray(1)
+ assert(array1.numElements() === 2)
+ assert(array1.getInt(0) === 4)
+ assert(array1.getInt(1) === 5)
+
+ assert(reader.isNullAt(2))
+
+ val array3 = reader.getArray(3)
+ assert(array3.numElements() === 0)
+
+ val array4 = reader.getArray(4)
+ assert(array4.numElements() === 3)
+ assert(array4.getInt(0) === 6)
+ assert(array4.isNullAt(1))
+ assert(array4.getInt(2) === 8)
+
+ writer.root.close()
+ }
+
+ test("nested array") {
+ val schema = new StructType().add("nested", ArrayType(ArrayType(IntegerType)))
+ val writer = ArrowWriter.create(schema)
+ assert(writer.schema === schema)
+
+ writer.write(InternalRow(ArrayData.toArrayData(Array(
+ ArrayData.toArrayData(Array(1, 2, 3)),
+ ArrayData.toArrayData(Array(4, 5)),
+ null,
+ ArrayData.toArrayData(Array.empty[Int]),
+ ArrayData.toArrayData(Array(6, null, 8))))))
+ writer.write(InternalRow(null))
+ writer.write(InternalRow(ArrayData.toArrayData(Array.empty)))
+ writer.finish()
+
+ val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
+
+ val array0 = reader.getArray(0)
+ assert(array0.numElements() === 5)
+
+ val array00 = array0.getArray(0)
+ assert(array00.numElements() === 3)
+ assert(array00.getInt(0) === 1)
+ assert(array00.getInt(1) === 2)
+ assert(array00.getInt(2) === 3)
+
+ val array01 = array0.getArray(1)
+ assert(array01.numElements() === 2)
+ assert(array01.getInt(0) === 4)
+ assert(array01.getInt(1) === 5)
+
+ assert(array0.isNullAt(2))
+
+ val array03 = array0.getArray(3)
+ assert(array03.numElements() === 0)
+
+ val array04 = array0.getArray(4)
+ assert(array04.numElements() === 3)
+ assert(array04.getInt(0) === 6)
+ assert(array04.isNullAt(1))
+ assert(array04.getInt(2) === 8)
+
+ assert(reader.isNullAt(1))
+
+ val array2 = reader.getArray(2)
+ assert(array2.numElements() === 0)
+
+ writer.root.close()
+ }
+
+ test("struct") {
+ val schema = new StructType()
+ .add("struct", new StructType().add("i", IntegerType).add("str", StringType))
+ val writer = ArrowWriter.create(schema)
+ assert(writer.schema === schema)
+
+ writer.write(InternalRow(InternalRow(1, UTF8String.fromString("str1"))))
+ writer.write(InternalRow(InternalRow(null, null)))
+ writer.write(InternalRow(null))
+ writer.write(InternalRow(InternalRow(4, null)))
+ writer.write(InternalRow(InternalRow(null, UTF8String.fromString("str5"))))
+ writer.finish()
+
+ val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
+
+ val struct0 = reader.getStruct(0, 2)
+ assert(struct0.getInt(0) === 1)
+ assert(struct0.getUTF8String(1) === UTF8String.fromString("str1"))
+
+ val struct1 = reader.getStruct(1, 2)
+ assert(struct1.isNullAt(0))
+ assert(struct1.isNullAt(1))
+
+ assert(reader.isNullAt(2))
+
+ val struct3 = reader.getStruct(3, 2)
+ assert(struct3.getInt(0) === 4)
+ assert(struct3.isNullAt(1))
+
+ val struct4 = reader.getStruct(4, 2)
+ assert(struct4.isNullAt(0))
+ assert(struct4.getUTF8String(1) === UTF8String.fromString("str5"))
+
+ writer.root.close()
+ }
+
+ test("nested struct") {
+ val schema = new StructType().add("struct",
+ new StructType().add("nested", new StructType().add("i", IntegerType).add("str", StringType)))
+ val writer = ArrowWriter.create(schema)
+ assert(writer.schema === schema)
+
+ writer.write(InternalRow(InternalRow(InternalRow(1, UTF8String.fromString("str1")))))
+ writer.write(InternalRow(InternalRow(InternalRow(null, null))))
+ writer.write(InternalRow(InternalRow(null)))
+ writer.write(InternalRow(null))
+ writer.finish()
+
+ val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
+
+ val struct00 = reader.getStruct(0, 1).getStruct(0, 2)
+ assert(struct00.getInt(0) === 1)
+ assert(struct00.getUTF8String(1) === UTF8String.fromString("str1"))
+
+ val struct10 = reader.getStruct(1, 1).getStruct(0, 2)
+ assert(struct10.isNullAt(0))
+ assert(struct10.isNullAt(1))
+
+ val struct2 = reader.getStruct(2, 1)
+ assert(struct2.isNullAt(0))
+
+ assert(reader.isNullAt(3))
+
+ writer.root.close()
+ }
+}