Posted to commits@spark.apache.org by we...@apache.org on 2018/07/15 14:06:39 UTC
spark git commit: [SPARK-24800][SQL] Refactor Avro Serializer and Deserializer
Repository: spark
Updated Branches:
refs/heads/master 69993217f -> 960308763
[SPARK-24800][SQL] Refactor Avro Serializer and Deserializer
## What changes were proposed in this pull request?
Currently the Avro Deserializer converts input Avro format data to `Row`, and then converts the `Row` to `InternalRow`.
Likewise, the Avro Serializer first converts `InternalRow` to `Row`, and then outputs Avro format data.
This PR allows direct conversion between `InternalRow` and Avro format data.
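Concretely, the read path now builds an `AvroDeserializer` and the write path an `AvroSerializer`, as the diffs below show. A minimal sketch of the new API (here `avroSchema`, `catalystSchema`, and `record` are hypothetical stand-ins for a parsed Avro schema, the matching Catalyst schema, and a record read from a file):

```scala
import org.apache.avro.generic.GenericRecord
import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer}
import org.apache.spark.sql.catalyst.InternalRow

// Avro -> InternalRow, with no intermediate Row / RowEncoder step.
val deserializer = new AvroDeserializer(avroSchema, catalystSchema)
val internalRow = deserializer.deserialize(record).asInstanceOf[InternalRow]

// InternalRow -> Avro, again without going through Row.
val serializer = new AvroSerializer(catalystSchema, avroSchema, nullable = false)
val avroRecord = serializer.serialize(internalRow).asInstanceOf[GenericRecord]
```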
## How was this patch tested?
Unit test
Author: Gengliang Wang <ge...@databricks.com>
Closes #21762 from gengliangwang/avro_io.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/96030876
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/96030876
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/96030876
Branch: refs/heads/master
Commit: 96030876383822645a5b35698ee407a8d4eb76af
Parents: 6999321
Author: Gengliang Wang <ge...@databricks.com>
Authored: Sun Jul 15 22:06:33 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Sun Jul 15 22:06:33 2018 +0800
----------------------------------------------------------------------
.../spark/sql/avro/AvroDeserializer.scala | 348 +++++++++++++++++++
.../apache/spark/sql/avro/AvroFileFormat.scala | 24 +-
.../spark/sql/avro/AvroOutputWriter.scala | 109 +-----
.../sql/avro/AvroOutputWriterFactory.scala | 5 +-
.../apache/spark/sql/avro/AvroSerializer.scala | 180 ++++++++++
.../spark/sql/avro/SchemaConverters.scala | 333 ++----------------
.../spark/sql/avro/SerializableSchema.scala | 69 ++++
.../org/apache/spark/sql/avro/AvroSuite.scala | 1 -
.../sql/avro/SerializableSchemaSuite.scala | 56 +++
9 files changed, 704 insertions(+), 421 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/96030876/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
----------------------------------------------------------------------
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
new file mode 100644
index 0000000..b31149a
--- /dev/null
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -0,0 +1,348 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import java.nio.ByteBuffer
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.avro.{Schema, SchemaBuilder}
+import org.apache.avro.Schema.Type._
+import org.apache.avro.generic._
+import org.apache.avro.util.Utf8
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A deserializer to deserialize data in avro format to data in catalyst format.
+ */
+class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
+ private val converter: Any => Any = rootCatalystType match {
+ // A shortcut for empty schema.
+ case st: StructType if st.isEmpty =>
+ (data: Any) => InternalRow.empty
+
+ case st: StructType =>
+ val resultRow = new SpecificInternalRow(st.map(_.dataType))
+ val fieldUpdater = new RowUpdater(resultRow)
+ val writer = getRecordWriter(rootAvroType, st, Nil)
+ (data: Any) => {
+ val record = data.asInstanceOf[GenericRecord]
+ writer(fieldUpdater, record)
+ resultRow
+ }
+
+ case _ =>
+ val tmpRow = new SpecificInternalRow(Seq(rootCatalystType))
+ val fieldUpdater = new RowUpdater(tmpRow)
+ val writer = newWriter(rootAvroType, rootCatalystType, Nil)
+ (data: Any) => {
+ writer(fieldUpdater, 0, data)
+ tmpRow.get(0, rootCatalystType)
+ }
+ }
+
+ def deserialize(data: Any): Any = converter(data)
+
+ /**
+ * Creates a writer to write avro values to Catalyst values at the given ordinal with the given
+ * updater.
+ */
+ private def newWriter(
+ avroType: Schema,
+ catalystType: DataType,
+ path: List[String]): (CatalystDataUpdater, Int, Any) => Unit =
+ (avroType.getType, catalystType) match {
+ case (NULL, NullType) => (updater, ordinal, _) =>
+ updater.setNullAt(ordinal)
+
+ // TODO: we can avoid boxing if future version of avro provide primitive accessors.
+ case (BOOLEAN, BooleanType) => (updater, ordinal, value) =>
+ updater.setBoolean(ordinal, value.asInstanceOf[Boolean])
+
+ case (INT, IntegerType) => (updater, ordinal, value) =>
+ updater.setInt(ordinal, value.asInstanceOf[Int])
+
+ case (LONG, LongType) => (updater, ordinal, value) =>
+ updater.setLong(ordinal, value.asInstanceOf[Long])
+
+ case (LONG, TimestampType) => (updater, ordinal, value) =>
+ updater.setLong(ordinal, value.asInstanceOf[Long] * 1000)
+
+ case (LONG, DateType) => (updater, ordinal, value) =>
+ updater.setInt(ordinal, (value.asInstanceOf[Long] / DateTimeUtils.MILLIS_PER_DAY).toInt)
+
+ case (FLOAT, FloatType) => (updater, ordinal, value) =>
+ updater.setFloat(ordinal, value.asInstanceOf[Float])
+
+ case (DOUBLE, DoubleType) => (updater, ordinal, value) =>
+ updater.setDouble(ordinal, value.asInstanceOf[Double])
+
+ case (STRING, StringType) => (updater, ordinal, value) =>
+ val str = value match {
+ case s: String => UTF8String.fromString(s)
+ case s: Utf8 =>
+ val bytes = new Array[Byte](s.getByteLength)
+ System.arraycopy(s.getBytes, 0, bytes, 0, s.getByteLength)
+ UTF8String.fromBytes(bytes)
+ }
+ updater.set(ordinal, str)
+
+ case (ENUM, StringType) => (updater, ordinal, value) =>
+ updater.set(ordinal, UTF8String.fromString(value.toString))
+
+ case (FIXED, BinaryType) => (updater, ordinal, value) =>
+ updater.set(ordinal, value.asInstanceOf[GenericFixed].bytes().clone())
+
+ case (BYTES, BinaryType) => (updater, ordinal, value) =>
+ val bytes = value match {
+ case b: ByteBuffer =>
+ val bytes = new Array[Byte](b.remaining)
+ b.get(bytes)
+ bytes
+ case b: Array[Byte] => b
+ case other => throw new RuntimeException(s"$other is not a valid avro binary.")
+
+ }
+ updater.set(ordinal, bytes)
+
+ case (RECORD, st: StructType) =>
+ val writeRecord = getRecordWriter(avroType, st, path)
+ (updater, ordinal, value) =>
+ val row = new SpecificInternalRow(st)
+ writeRecord(new RowUpdater(row), value.asInstanceOf[GenericRecord])
+ updater.set(ordinal, row)
+
+ case (ARRAY, ArrayType(elementType, containsNull)) =>
+ val elementWriter = newWriter(avroType.getElementType, elementType, path)
+ (updater, ordinal, value) =>
+ val array = value.asInstanceOf[GenericData.Array[Any]]
+ val len = array.size()
+ val result = createArrayData(elementType, len)
+ val elementUpdater = new ArrayDataUpdater(result)
+
+ var i = 0
+ while (i < len) {
+ val element = array.get(i)
+ if (element == null) {
+ if (!containsNull) {
+ throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " +
+ "allowed to be null")
+ } else {
+ elementUpdater.setNullAt(i)
+ }
+ } else {
+ elementWriter(elementUpdater, i, element)
+ }
+ i += 1
+ }
+
+ updater.set(ordinal, result)
+
+ case (MAP, MapType(keyType, valueType, valueContainsNull)) if keyType == StringType =>
+ val keyWriter = newWriter(SchemaBuilder.builder().stringType(), StringType, path)
+ val valueWriter = newWriter(avroType.getValueType, valueType, path)
+ (updater, ordinal, value) =>
+ val map = value.asInstanceOf[java.util.Map[AnyRef, AnyRef]]
+ val keyArray = createArrayData(keyType, map.size())
+ val keyUpdater = new ArrayDataUpdater(keyArray)
+ val valueArray = createArrayData(valueType, map.size())
+ val valueUpdater = new ArrayDataUpdater(valueArray)
+ val iter = map.entrySet().iterator()
+ var i = 0
+ while (iter.hasNext) {
+ val entry = iter.next()
+ assert(entry.getKey != null)
+ keyWriter(keyUpdater, i, entry.getKey)
+ if (entry.getValue == null) {
+ if (!valueContainsNull) {
+ throw new RuntimeException(s"Map value at path ${path.mkString(".")} is not " +
+ "allowed to be null")
+ } else {
+ valueUpdater.setNullAt(i)
+ }
+ } else {
+ valueWriter(valueUpdater, i, entry.getValue)
+ }
+ i += 1
+ }
+
+ updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray))
+
+ case (UNION, _) =>
+ val allTypes = avroType.getTypes.asScala
+ val nonNullTypes = allTypes.filter(_.getType != NULL)
+ if (nonNullTypes.nonEmpty) {
+ if (nonNullTypes.length == 1) {
+ newWriter(nonNullTypes.head, catalystType, path)
+ } else {
+ nonNullTypes.map(_.getType) match {
+ case Seq(a, b) if Set(a, b) == Set(INT, LONG) && catalystType == LongType =>
+ (updater, ordinal, value) => value match {
+ case null => updater.setNullAt(ordinal)
+ case l: java.lang.Long => updater.setLong(ordinal, l)
+ case i: java.lang.Integer => updater.setLong(ordinal, i.longValue())
+ }
+
+ case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && catalystType == DoubleType =>
+ (updater, ordinal, value) => value match {
+ case null => updater.setNullAt(ordinal)
+ case d: java.lang.Double => updater.setDouble(ordinal, d)
+ case f: java.lang.Float => updater.setDouble(ordinal, f.doubleValue())
+ }
+
+ case _ =>
+ catalystType match {
+ case st: StructType if st.length == nonNullTypes.size =>
+ val fieldWriters = nonNullTypes.zip(st.fields).map {
+ case (schema, field) => newWriter(schema, field.dataType, path :+ field.name)
+ }.toArray
+ (updater, ordinal, value) => {
+ val row = new SpecificInternalRow(st)
+ val fieldUpdater = new RowUpdater(row)
+ val i = GenericData.get().resolveUnion(avroType, value)
+ fieldWriters(i)(fieldUpdater, i, value)
+ updater.set(ordinal, row)
+ }
+
+ case _ =>
+ throw new IncompatibleSchemaException(
+ s"Cannot convert Avro to catalyst because schema at path " +
+ s"${path.mkString(".")} is not compatible " +
+ s"(avroType = $avroType, sqlType = $catalystType).\n" +
+ s"Source Avro schema: $rootAvroType.\n" +
+ s"Target Catalyst type: $rootCatalystType")
+ }
+ }
+ }
+ } else {
+ (updater, ordinal, value) => updater.setNullAt(ordinal)
+ }
+
+ case _ =>
+ throw new IncompatibleSchemaException(
+ s"Cannot convert Avro to catalyst because schema at path ${path.mkString(".")} " +
+ s"is not compatible (avroType = $avroType, sqlType = $catalystType).\n" +
+ s"Source Avro schema: $rootAvroType.\n" +
+ s"Target Catalyst type: $rootCatalystType")
+ }
+
+ private def getRecordWriter(
+ avroType: Schema,
+ sqlType: StructType,
+ path: List[String]): (CatalystDataUpdater, GenericRecord) => Unit = {
+ val validFieldIndexes = ArrayBuffer.empty[Int]
+ val fieldWriters = ArrayBuffer.empty[(CatalystDataUpdater, Any) => Unit]
+
+ val length = sqlType.length
+ var i = 0
+ while (i < length) {
+ val sqlField = sqlType.fields(i)
+ val avroField = avroType.getField(sqlField.name)
+ if (avroField != null) {
+ validFieldIndexes += avroField.pos()
+
+ val baseWriter = newWriter(avroField.schema(), sqlField.dataType, path :+ sqlField.name)
+ val ordinal = i
+ val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => {
+ if (value == null) {
+ fieldUpdater.setNullAt(ordinal)
+ } else {
+ baseWriter(fieldUpdater, ordinal, value)
+ }
+ }
+ fieldWriters += fieldWriter
+ } else if (!sqlField.nullable) {
+ throw new IncompatibleSchemaException(
+ s"""
+ |Cannot find non-nullable field ${path.mkString(".")}.${sqlField.name} in Avro schema.
+ |Source Avro schema: $rootAvroType.
+ |Target Catalyst type: $rootCatalystType.
+ """.stripMargin)
+ }
+ i += 1
+ }
+
+ (fieldUpdater, record) => {
+ var i = 0
+ while (i < validFieldIndexes.length) {
+ fieldWriters(i)(fieldUpdater, record.get(validFieldIndexes(i)))
+ i += 1
+ }
+ }
+ }
+
+ private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match {
+ case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length))
+ case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length))
+ case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length))
+ case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length))
+ case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length))
+ case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length))
+ case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length))
+ case _ => new GenericArrayData(new Array[Any](length))
+ }
+
+ /**
+ * A base interface for updating values inside catalyst data structure like `InternalRow` and
+ * `ArrayData`.
+ */
+ sealed trait CatalystDataUpdater {
+ def set(ordinal: Int, value: Any): Unit
+
+ def setNullAt(ordinal: Int): Unit = set(ordinal, null)
+ def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value)
+ def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value)
+ def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value)
+ def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value)
+ def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value)
+ def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value)
+ def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value)
+ }
+
+ final class RowUpdater(row: InternalRow) extends CatalystDataUpdater {
+ override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value)
+
+ override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal)
+ override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value)
+ override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value)
+ override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value)
+ override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value)
+ override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value)
+ override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value)
+ override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value)
+ }
+
+ final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater {
+ override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value)
+
+ override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal)
+ override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value)
+ override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value)
+ override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value)
+ override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value)
+ override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value)
+ override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value)
+ override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value)
+ }
+}
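A note on the `(LONG, TimestampType)` and `(LONG, DateType)` cases above: Catalyst stores timestamps as microseconds and dates as days since the epoch, while this Avro mapping carries milliseconds, hence the `* 1000` and the division by `MILLIS_PER_DAY`. A quick arithmetic sketch (not part of the commit):

```scala
import org.apache.spark.sql.catalyst.util.DateTimeUtils

// Avro long timestamp in milliseconds -> Catalyst microseconds.
val avroMillis = 1531663593000L                   // 2018-07-15 14:06:33 UTC
val catalystMicros = avroMillis * 1000            // 1531663593000000L

// Avro long date in milliseconds -> Catalyst days since the epoch.
val catalystDays = (avroMillis / DateTimeUtils.MILLIS_PER_DAY).toInt  // 17727
```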
http://git-wip-us.apache.org/repos/asf/spark/blob/96030876/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
----------------------------------------------------------------------
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
index 46e5a18..fb93033 100755
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
@@ -25,7 +25,7 @@ import scala.util.control.NonFatal
import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}
-import org.apache.avro.{Schema, SchemaBuilder}
+import org.apache.avro.Schema
import org.apache.avro.file.{DataFileConstants, DataFileReader}
import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
import org.apache.avro.mapred.{AvroOutputFormat, FsInput}
@@ -38,8 +38,6 @@ import org.slf4j.LoggerFactory
import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
-import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
import org.apache.spark.sql.types.StructType
@@ -118,8 +116,8 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister {
dataSchema: StructType): OutputWriterFactory = {
val recordName = options.getOrElse("recordName", "topLevelRecord")
val recordNamespace = options.getOrElse("recordNamespace", "")
- val build = SchemaBuilder.record(recordName).namespace(recordNamespace)
- val outputAvroSchema = SchemaConverters.convertStructToAvro(dataSchema, build, recordNamespace)
+ val outputAvroSchema = SchemaConverters.toAvroType(
+ dataSchema, nullable = false, recordName, recordNamespace)
AvroJob.setOutputKeySchema(job, outputAvroSchema)
val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec"
@@ -148,7 +146,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister {
log.error(s"unsupported compression codec $unknown")
}
- new AvroOutputWriterFactory(dataSchema, recordName, recordNamespace)
+ new AvroOutputWriterFactory(dataSchema, new SerializableSchema(outputAvroSchema))
}
override def buildReader(
@@ -205,13 +203,10 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister {
reader.sync(file.start)
val stop = file.start + file.length
- val rowConverter = SchemaConverters.createConverterToSQL(
- userProvidedSchema.getOrElse(reader.getSchema), requiredSchema)
+ val deserializer =
+ new AvroDeserializer(userProvidedSchema.getOrElse(reader.getSchema), requiredSchema)
new Iterator[InternalRow] {
- // Used to convert `Row`s containing data columns into `InternalRow`s.
- private val encoderForDataColumns = RowEncoder(requiredSchema)
-
private[this] var completed = false
override def hasNext: Boolean = {
@@ -228,14 +223,11 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister {
}
override def next(): InternalRow = {
- if (reader.pastSync(stop)) {
+ if (!hasNext) {
throw new NoSuchElementException("next on empty iterator")
}
val record = reader.next()
- val safeDataRow = rowConverter(record).asInstanceOf[GenericRow]
-
- // The safeDataRow is reused, we must do a copy
- encoderForDataColumns.toRow(safeDataRow)
+ deserializer.deserialize(record).asInstanceOf[InternalRow]
}
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/96030876/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala
----------------------------------------------------------------------
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala
index 830bf3c..0650711 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala
@@ -18,14 +18,8 @@
package org.apache.spark.sql.avro
import java.io.{IOException, OutputStream}
-import java.nio.ByteBuffer
-import java.sql.{Date, Timestamp}
-import java.util.HashMap
-import scala.collection.immutable.Map
-
-import org.apache.avro.{Schema, SchemaBuilder}
-import org.apache.avro.generic.GenericData.Record
+import org.apache.avro.Schema
import org.apache.avro.generic.GenericRecord
import org.apache.avro.mapred.AvroKey
import org.apache.avro.mapreduce.AvroKeyOutputFormat
@@ -33,8 +27,7 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.types._
@@ -43,13 +36,10 @@ private[avro] class AvroOutputWriter(
path: String,
context: TaskAttemptContext,
schema: StructType,
- recordName: String,
- recordNamespace: String) extends OutputWriter {
+ avroSchema: Schema) extends OutputWriter {
- private lazy val converter = createConverterToAvro(schema, recordName, recordNamespace)
- // copy of the old conversion logic after api change in SPARK-19085
- private lazy val internalRowConverter =
- CatalystTypeConverters.createToScalaConverter(schema).asInstanceOf[InternalRow => Row]
+ // The input rows will never be null.
+ private lazy val serializer = new AvroSerializer(schema, avroSchema, nullable = false)
/**
* Overrides the couple of methods responsible for generating the output streams / files so
@@ -70,95 +60,10 @@ private[avro] class AvroOutputWriter(
}.getRecordWriter(context)
- override def write(internalRow: InternalRow): Unit = {
- val row = internalRowConverter(internalRow)
- val key = new AvroKey(converter(row).asInstanceOf[GenericRecord])
+ override def write(row: InternalRow): Unit = {
+ val key = new AvroKey(serializer.serialize(row).asInstanceOf[GenericRecord])
recordWriter.write(key, NullWritable.get())
}
override def close(): Unit = recordWriter.close(context)
-
- /**
- * This function constructs converter function for a given sparkSQL datatype. This is used in
- * writing Avro records out to disk
- */
- private def createConverterToAvro(
- dataType: DataType,
- structName: String,
- recordNamespace: String): (Any) => Any = {
- dataType match {
- case BinaryType => (item: Any) => item match {
- case null => null
- case bytes: Array[Byte] => ByteBuffer.wrap(bytes)
- }
- case ByteType | ShortType | IntegerType | LongType |
- FloatType | DoubleType | StringType | BooleanType => identity
- case _: DecimalType => (item: Any) => if (item == null) null else item.toString
- case TimestampType => (item: Any) =>
- if (item == null) null else item.asInstanceOf[Timestamp].getTime
- case DateType => (item: Any) =>
- if (item == null) null else item.asInstanceOf[Date].getTime
- case ArrayType(elementType, _) =>
- val elementConverter = createConverterToAvro(
- elementType,
- structName,
- SchemaConverters.getNewRecordNamespace(elementType, recordNamespace, structName))
- (item: Any) => {
- if (item == null) {
- null
- } else {
- val sourceArray = item.asInstanceOf[Seq[Any]]
- val sourceArraySize = sourceArray.size
- val targetArray = new Array[Any](sourceArraySize)
- var idx = 0
- while (idx < sourceArraySize) {
- targetArray(idx) = elementConverter(sourceArray(idx))
- idx += 1
- }
- targetArray
- }
- }
- case MapType(StringType, valueType, _) =>
- val valueConverter = createConverterToAvro(
- valueType,
- structName,
- SchemaConverters.getNewRecordNamespace(valueType, recordNamespace, structName))
- (item: Any) => {
- if (item == null) {
- null
- } else {
- val javaMap = new HashMap[String, Any]()
- item.asInstanceOf[Map[String, Any]].foreach { case (key, value) =>
- javaMap.put(key, valueConverter(value))
- }
- javaMap
- }
- }
- case structType: StructType =>
- val builder = SchemaBuilder.record(structName).namespace(recordNamespace)
- val schema: Schema = SchemaConverters.convertStructToAvro(
- structType, builder, recordNamespace)
- val fieldConverters = structType.fields.map(field =>
- createConverterToAvro(
- field.dataType,
- field.name,
- SchemaConverters.getNewRecordNamespace(field.dataType, recordNamespace, field.name)))
- (item: Any) => {
- if (item == null) {
- null
- } else {
- val record = new Record(schema)
- val convertersIterator = fieldConverters.iterator
- val fieldNamesIterator = dataType.asInstanceOf[StructType].fieldNames.iterator
- val rowIterator = item.asInstanceOf[Row].toSeq.iterator
-
- while (convertersIterator.hasNext) {
- val converter = convertersIterator.next()
- record.put(fieldNamesIterator.next(), converter(rowIterator.next()))
- }
- record
- }
- }
- }
- }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/96030876/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala
----------------------------------------------------------------------
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala
index 5b2ce7d..18a6d93 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala
@@ -24,8 +24,7 @@ import org.apache.spark.sql.types.StructType
private[avro] class AvroOutputWriterFactory(
schema: StructType,
- recordName: String,
- recordNamespace: String) extends OutputWriterFactory {
+ avroSchema: SerializableSchema) extends OutputWriterFactory {
override def getFileExtension(context: TaskAttemptContext): String = ".avro"
@@ -33,6 +32,6 @@ private[avro] class AvroOutputWriterFactory(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- new AvroOutputWriter(path, context, schema, recordName, recordNamespace)
+ new AvroOutputWriter(path, context, schema, avroSchema.value)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/96030876/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
----------------------------------------------------------------------
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
new file mode 100644
index 0000000..2b4c581
--- /dev/null
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import java.nio.ByteBuffer
+
+import scala.collection.JavaConverters._
+
+import org.apache.avro.Schema
+import org.apache.avro.Schema.Type.NULL
+import org.apache.avro.generic.GenericData.Record
+import org.apache.avro.util.Utf8
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+
+/**
+ * A serializer to serialize data in catalyst format to data in avro format.
+ */
+class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) {
+
+ def serialize(catalystData: Any): Any = {
+ converter.apply(catalystData)
+ }
+
+ private val converter: Any => Any = {
+ val actualAvroType = resolveNullableType(rootAvroType, nullable)
+ val baseConverter = rootCatalystType match {
+ case st: StructType =>
+ newStructConverter(st, actualAvroType).asInstanceOf[Any => Any]
+ case _ =>
+ val tmpRow = new SpecificInternalRow(Seq(rootCatalystType))
+ val converter = newConverter(rootCatalystType, actualAvroType)
+ (data: Any) =>
+ tmpRow.update(0, data)
+ converter.apply(tmpRow, 0)
+ }
+ if (nullable) {
+ (data: Any) =>
+ if (data == null) {
+ null
+ } else {
+ baseConverter.apply(data)
+ }
+ } else {
+ baseConverter
+ }
+ }
+
+ private type Converter = (SpecializedGetters, Int) => Any
+
+ private def newConverter(catalystType: DataType, avroType: Schema): Converter = {
+ catalystType match {
+ case NullType =>
+ (getter, ordinal) => null
+ case BooleanType =>
+ (getter, ordinal) => getter.getBoolean(ordinal)
+ case ByteType =>
+ (getter, ordinal) => getter.getByte(ordinal).toInt
+ case ShortType =>
+ (getter, ordinal) => getter.getShort(ordinal).toInt
+ case IntegerType =>
+ (getter, ordinal) => getter.getInt(ordinal)
+ case LongType =>
+ (getter, ordinal) => getter.getLong(ordinal)
+ case FloatType =>
+ (getter, ordinal) => getter.getFloat(ordinal)
+ case DoubleType =>
+ (getter, ordinal) => getter.getDouble(ordinal)
+ case d: DecimalType =>
+ (getter, ordinal) => getter.getDecimal(ordinal, d.precision, d.scale).toString
+ case StringType =>
+ (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes)
+ case BinaryType =>
+ (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal))
+ case DateType =>
+ (getter, ordinal) => getter.getInt(ordinal) * DateTimeUtils.MILLIS_PER_DAY
+ case TimestampType =>
+ (getter, ordinal) => getter.getLong(ordinal) / 1000
+
+ case ArrayType(et, containsNull) =>
+ val elementConverter = newConverter(
+ et, resolveNullableType(avroType.getElementType, containsNull))
+ (getter, ordinal) => {
+ val arrayData = getter.getArray(ordinal)
+ val result = new java.util.ArrayList[Any]
+ var i = 0
+ while (i < arrayData.numElements()) {
+ if (arrayData.isNullAt(i)) {
+ result.add(null)
+ } else {
+ result.add(elementConverter(arrayData, i))
+ }
+ i += 1
+ }
+ result
+ }
+
+ case st: StructType =>
+ val structConverter = newStructConverter(st, avroType)
+ val numFields = st.length
+ (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields))
+
+ case MapType(kt, vt, valueContainsNull) if kt == StringType =>
+ val valueConverter = newConverter(
+ vt, resolveNullableType(avroType.getValueType, valueContainsNull))
+ (getter, ordinal) =>
+ val mapData = getter.getMap(ordinal)
+ val result = new java.util.HashMap[String, Any](mapData.numElements())
+ val keyArray = mapData.keyArray()
+ val valueArray = mapData.valueArray()
+ var i = 0
+ while (i < mapData.numElements()) {
+ val key = keyArray.getUTF8String(i).toString
+ if (valueArray.isNullAt(i)) {
+ result.put(key, null)
+ } else {
+ result.put(key, valueConverter(valueArray, i))
+ }
+ i += 1
+ }
+ result
+
+ case other =>
+ throw new IncompatibleSchemaException(s"Unexpected type: $other")
+ }
+ }
+
+ private def newStructConverter(
+ catalystStruct: StructType, avroStruct: Schema): InternalRow => Record = {
+ val avroFields = avroStruct.getFields
+ assert(avroFields.size() == catalystStruct.length)
+ val fieldConverters = catalystStruct.zip(avroFields.asScala).map {
+ case (f1, f2) => newConverter(f1.dataType, resolveNullableType(f2.schema(), f1.nullable))
+ }
+ val numFields = catalystStruct.length
+ (row: InternalRow) =>
+ val result = new Record(avroStruct)
+ var i = 0
+ while (i < numFields) {
+ if (row.isNullAt(i)) {
+ result.put(i, null)
+ } else {
+ result.put(i, fieldConverters(i).apply(row, i))
+ }
+ i += 1
+ }
+ result
+ }
+
+ private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = {
+ if (nullable) {
+ // avro uses union to represent nullable type.
+ val fields = avroType.getTypes.asScala
+ assert(fields.length == 2)
+ val actualType = fields.filter(_.getType != NULL)
+ assert(actualType.length == 1)
+ actualType.head
+ } else {
+ avroType
+ }
+ }
+}
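`resolveNullableType` at the bottom relies on Avro's convention of modelling a nullable type as a two-branch union with `null`; the assertions there encode exactly that shape. A minimal illustration, mirroring the `SchemaBuilder.builder().nullable()` call used in `SchemaConverters` below (branch order inside the union is not assumed):

```scala
import scala.collection.JavaConverters._

import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.Schema.Type.NULL

// Avro represents a nullable string as a union of null and string.
val nullableString: Schema = SchemaBuilder.builder().nullable().stringType()
assert(nullableString.getType == Schema.Type.UNION)

// Unwrap the union to its single non-null branch, as resolveNullableType does.
val branches = nullableString.getTypes.asScala
assert(branches.length == 2)
assert(branches.count(_.getType != NULL) == 1)
```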
http://git-wip-us.apache.org/repos/asf/spark/blob/96030876/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
----------------------------------------------------------------------
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
index 01f8c74..87fae63 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
@@ -17,18 +17,11 @@
package org.apache.spark.sql.avro
-import java.nio.ByteBuffer
-import java.sql.{Date, Timestamp}
-
import scala.collection.JavaConverters._
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.Schema.Type._
-import org.apache.avro.SchemaBuilder._
-import org.apache.avro.generic.{GenericData, GenericRecord}
-import org.apache.avro.generic.GenericFixed
-import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._
/**
@@ -36,9 +29,6 @@ import org.apache.spark.sql.types._
* versa.
*/
object SchemaConverters {
-
- class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex)
-
case class SchemaType(dataType: DataType, nullable: Boolean)
/**
@@ -109,298 +99,43 @@ object SchemaConverters {
}
}
- /**
- * This function converts sparkSQL StructType into avro schema. This method uses two other
- * converter methods in order to do the conversion.
- */
- def convertStructToAvro[T](
- structType: StructType,
- schemaBuilder: RecordBuilder[T],
- recordNamespace: String): T = {
- val fieldsAssembler: FieldAssembler[T] = schemaBuilder.fields()
- structType.fields.foreach { field =>
- val newField = fieldsAssembler.name(field.name).`type`()
-
- if (field.nullable) {
- convertFieldTypeToAvro(field.dataType, newField.nullable(), field.name, recordNamespace)
- .noDefault
- } else {
- convertFieldTypeToAvro(field.dataType, newField, field.name, recordNamespace)
- .noDefault
- }
- }
- fieldsAssembler.endRecord()
- }
-
- /**
- * Returns a converter function to convert row in avro format to GenericRow of catalyst.
- *
- * @param sourceAvroSchema Source schema before conversion inferred from avro file by passed in
- * by user.
- * @param targetSqlType Target catalyst sql type after the conversion.
- * @return returns a converter function to convert row in avro format to GenericRow of catalyst.
- */
- private[avro] def createConverterToSQL(
- sourceAvroSchema: Schema,
- targetSqlType: DataType): AnyRef => AnyRef = {
-
- def createConverter(avroSchema: Schema,
- sqlType: DataType, path: List[String]): AnyRef => AnyRef = {
- val avroType = avroSchema.getType
- (sqlType, avroType) match {
- // Avro strings are in Utf8, so we have to call toString on them
- case (StringType, STRING) | (StringType, ENUM) =>
- (item: AnyRef) => item.toString
- // Byte arrays are reused by avro, so we have to make a copy of them.
- case (IntegerType, INT) | (BooleanType, BOOLEAN) | (DoubleType, DOUBLE) |
- (FloatType, FLOAT) | (LongType, LONG) =>
- identity
- case (TimestampType, LONG) =>
- (item: AnyRef) => new Timestamp(item.asInstanceOf[Long])
- case (DateType, LONG) =>
- (item: AnyRef) => new Date(item.asInstanceOf[Long])
- case (BinaryType, FIXED) =>
- (item: AnyRef) => item.asInstanceOf[GenericFixed].bytes().clone()
- case (BinaryType, BYTES) =>
- (item: AnyRef) =>
- val byteBuffer = item.asInstanceOf[ByteBuffer]
- val bytes = new Array[Byte](byteBuffer.remaining)
- byteBuffer.get(bytes)
- bytes
- case (struct: StructType, RECORD) =>
- val length = struct.fields.length
- val converters = new Array[AnyRef => AnyRef](length)
- val avroFieldIndexes = new Array[Int](length)
- var i = 0
- while (i < length) {
- val sqlField = struct.fields(i)
- val avroField = avroSchema.getField(sqlField.name)
- if (avroField != null) {
- val converter = (item: AnyRef) => {
- if (item == null) {
- item
- } else {
- createConverter(avroField.schema, sqlField.dataType, path :+ sqlField.name)(item)
- }
- }
- converters(i) = converter
- avroFieldIndexes(i) = avroField.pos()
- } else if (!sqlField.nullable) {
- throw new IncompatibleSchemaException(
- s"Cannot find non-nullable field ${sqlField.name} at path ${path.mkString(".")} " +
- "in Avro schema\n" +
- s"Source Avro schema: $sourceAvroSchema.\n" +
- s"Target Catalyst type: $targetSqlType")
- }
- i += 1
- }
-
- (item: AnyRef) =>
- val record = item.asInstanceOf[GenericRecord]
- val result = new Array[Any](length)
- var i = 0
- while (i < converters.length) {
- if (converters(i) != null) {
- val converter = converters(i)
- result(i) = converter(record.get(avroFieldIndexes(i)))
- }
- i += 1
- }
- new GenericRow(result)
- case (arrayType: ArrayType, ARRAY) =>
- val elementConverter = createConverter(avroSchema.getElementType, arrayType.elementType,
- path)
- val allowsNull = arrayType.containsNull
- (item: AnyRef) =>
- item.asInstanceOf[java.lang.Iterable[AnyRef]].asScala.map { element =>
- if (element == null && !allowsNull) {
- throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " +
- "allowed to be null")
- } else {
- elementConverter(element)
- }
- }
- case (mapType: MapType, MAP) if mapType.keyType == StringType =>
- val valueConverter = createConverter(avroSchema.getValueType, mapType.valueType, path)
- val allowsNull = mapType.valueContainsNull
- (item: AnyRef) =>
- item.asInstanceOf[java.util.Map[AnyRef, AnyRef]].asScala.map { case (k, v) =>
- if (v == null && !allowsNull) {
- throw new RuntimeException(s"Map value at path ${path.mkString(".")} is not " +
- "allowed to be null")
- } else {
- (k.toString, valueConverter(v))
- }
- }.toMap
- case (sqlType, UNION) =>
- if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
- val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
- if (remainingUnionTypes.size == 1) {
- createConverter(remainingUnionTypes.head, sqlType, path)
- } else {
- createConverter(Schema.createUnion(remainingUnionTypes.asJava), sqlType, path)
- }
- } else avroSchema.getTypes.asScala.map(_.getType) match {
- case Seq(t1) => createConverter(avroSchema.getTypes.get(0), sqlType, path)
- case Seq(a, b) if Set(a, b) == Set(INT, LONG) && sqlType == LongType =>
- (item: AnyRef) =>
- item match {
- case l: java.lang.Long => l
- case i: java.lang.Integer => new java.lang.Long(i.longValue())
- }
- case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && sqlType == DoubleType =>
- (item: AnyRef) =>
- item match {
- case d: java.lang.Double => d
- case f: java.lang.Float => new java.lang.Double(f.doubleValue())
- }
- case other =>
- sqlType match {
- case t: StructType if t.fields.length == avroSchema.getTypes.size =>
- val fieldConverters = t.fields.zip(avroSchema.getTypes.asScala).map {
- case (field, schema) =>
- createConverter(schema, field.dataType, path :+ field.name)
- }
- (item: AnyRef) =>
- val i = GenericData.get().resolveUnion(avroSchema, item)
- val converted = new Array[Any](fieldConverters.length)
- converted(i) = fieldConverters(i)(item)
- new GenericRow(converted)
- case _ => throw new IncompatibleSchemaException(
- s"Cannot convert Avro schema to catalyst type because schema at path " +
- s"${path.mkString(".")} is not compatible " +
- s"(avroType = $other, sqlType = $sqlType). \n" +
- s"Source Avro schema: $sourceAvroSchema.\n" +
- s"Target Catalyst type: $targetSqlType")
- }
- }
- case (left, right) =>
- throw new IncompatibleSchemaException(
- s"Cannot convert Avro schema to catalyst type because schema at path " +
- s"${path.mkString(".")} is not compatible (avroType = $right, sqlType = $left). \n" +
- s"Source Avro schema: $sourceAvroSchema.\n" +
- s"Target Catalyst type: $targetSqlType")
- }
- }
- createConverter(sourceAvroSchema, targetSqlType, List.empty[String])
- }
-
- /**
- * This function is used to convert some sparkSQL type to avro type. Note that this function won't
- * be used to construct fields of avro record (convertFieldTypeToAvro is used for that).
- */
- private def convertTypeToAvro[T](
- dataType: DataType,
- schemaBuilder: BaseTypeBuilder[T],
- structName: String,
- recordNamespace: String): T = {
- dataType match {
- case ByteType => schemaBuilder.intType()
- case ShortType => schemaBuilder.intType()
- case IntegerType => schemaBuilder.intType()
- case LongType => schemaBuilder.longType()
- case FloatType => schemaBuilder.floatType()
- case DoubleType => schemaBuilder.doubleType()
- case _: DecimalType => schemaBuilder.stringType()
- case StringType => schemaBuilder.stringType()
- case BinaryType => schemaBuilder.bytesType()
- case BooleanType => schemaBuilder.booleanType()
- case TimestampType => schemaBuilder.longType()
- case DateType => schemaBuilder.longType()
-
- case ArrayType(elementType, _) =>
- val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull)
- val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace)
- schemaBuilder.array().items(elementSchema)
-
- case MapType(StringType, valueType, _) =>
- val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull)
- val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace)
- schemaBuilder.map().values(valueSchema)
-
- case structType: StructType =>
- convertStructToAvro(
- structType,
- schemaBuilder.record(structName).namespace(recordNamespace),
- recordNamespace)
-
- case other => throw new IncompatibleSchemaException(s"Unexpected type $dataType.")
- }
- }
-
- /**
- * This function is used to construct fields of the avro record, where schema of the field is
- * specified by avro representation of dataType. Since builders for record fields are different
- * from those for everything else, we have to use a separate method.
- */
- private def convertFieldTypeToAvro[T](
- dataType: DataType,
- newFieldBuilder: BaseFieldTypeBuilder[T],
- structName: String,
- recordNamespace: String): FieldDefault[T, _] = {
- dataType match {
- case ByteType => newFieldBuilder.intType()
- case ShortType => newFieldBuilder.intType()
- case IntegerType => newFieldBuilder.intType()
- case LongType => newFieldBuilder.longType()
- case FloatType => newFieldBuilder.floatType()
- case DoubleType => newFieldBuilder.doubleType()
- case _: DecimalType => newFieldBuilder.stringType()
- case StringType => newFieldBuilder.stringType()
- case BinaryType => newFieldBuilder.bytesType()
- case BooleanType => newFieldBuilder.booleanType()
- case TimestampType => newFieldBuilder.longType()
- case DateType => newFieldBuilder.longType()
-
- case ArrayType(elementType, _) =>
- val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull)
- val elementSchema = convertTypeToAvro(
- elementType,
- builder,
- structName,
- getNewRecordNamespace(elementType, recordNamespace, structName))
- newFieldBuilder.array().items(elementSchema)
-
- case MapType(StringType, valueType, _) =>
- val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull)
- val valueSchema = convertTypeToAvro(
- valueType,
- builder,
- structName,
- getNewRecordNamespace(valueType, recordNamespace, structName))
- newFieldBuilder.map().values(valueSchema)
-
- case structType: StructType =>
- convertStructToAvro(
- structType,
- newFieldBuilder.record(structName).namespace(s"$recordNamespace.$structName"),
- s"$recordNamespace.$structName")
-
- case other => throw new IncompatibleSchemaException(s"Unexpected type $dataType.")
- }
- }
-
- /**
- * Returns a new namespace depending on the data type of the element.
- * If the data type is a StructType it returns the current namespace concatenated
- * with the element name, otherwise it returns the current namespace as it is.
- */
- private[avro] def getNewRecordNamespace(
- elementDataType: DataType,
- currentRecordNamespace: String,
- elementName: String): String = {
-
- elementDataType match {
- case StructType(_) => s"$currentRecordNamespace.$elementName"
- case _ => currentRecordNamespace
- }
- }
-
- private def getSchemaBuilder(isNullable: Boolean): BaseTypeBuilder[Schema] = {
- if (isNullable) {
+ def toAvroType(
+ catalystType: DataType,
+ nullable: Boolean = false,
+ recordName: String = "topLevelRecord",
+ prevNameSpace: String = ""): Schema = {
+ val builder = if (nullable) {
SchemaBuilder.builder().nullable()
} else {
SchemaBuilder.builder()
}
+ catalystType match {
+ case BooleanType => builder.booleanType()
+ case ByteType | ShortType | IntegerType => builder.intType()
+ case LongType => builder.longType()
+ case DateType => builder.longType()
+ case TimestampType => builder.longType()
+ case FloatType => builder.floatType()
+ case DoubleType => builder.doubleType()
+ case _: DecimalType | StringType => builder.stringType()
+ case BinaryType => builder.bytesType()
+ case ArrayType(et, containsNull) =>
+ builder.array().items(toAvroType(et, containsNull, recordName, prevNameSpace))
+ case MapType(StringType, vt, valueContainsNull) =>
+ builder.map().values(toAvroType(vt, valueContainsNull, recordName, prevNameSpace))
+ case st: StructType =>
+ val nameSpace = s"$prevNameSpace.$recordName"
+ val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields()
+ st.foreach { f =>
+ val fieldAvroType = toAvroType(f.dataType, f.nullable, f.name, nameSpace)
+ fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault()
+ }
+ fieldsAssembler.endRecord()
+
+ // This should never happen.
+ case other => throw new IncompatibleSchemaException(s"Unexpected type $other.")
+ }
}
}
+
+class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex)
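The new `toAvroType` replaces the deleted builder-based converters (`convertStructToAvro`, `convertTypeToAvro`, `convertFieldTypeToAvro`) and is what `AvroFileFormat` now calls to prepare the output schema. A small usage sketch with a hypothetical two-field schema:

```scala
import org.apache.spark.sql.avro.SchemaConverters
import org.apache.spark.sql.types._

val catalystSchema = StructType(Seq(
  StructField("id", LongType, nullable = false),
  StructField("name", StringType, nullable = true)))

// Defaults: recordName = "topLevelRecord", empty namespace. The nullable
// "name" field comes out as a union of string and null.
val avroSchema = SchemaConverters.toAvroType(catalystSchema)
println(avroSchema.toString(true))
```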
http://git-wip-us.apache.org/repos/asf/spark/blob/96030876/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala
----------------------------------------------------------------------
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala
new file mode 100644
index 0000000..ec0ddc7
--- /dev/null
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import java.io._
+
+import scala.util.control.NonFatal
+
+import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
+import com.esotericsoftware.kryo.io.{Input, Output}
+import org.apache.avro.Schema
+import org.slf4j.LoggerFactory
+
+class SerializableSchema(@transient var value: Schema)
+ extends Serializable with KryoSerializable {
+
+ @transient private[avro] lazy val log = LoggerFactory.getLogger(getClass)
+
+ private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException {
+ out.defaultWriteObject()
+ out.writeUTF(value.toString())
+ out.flush()
+ }
+
+ private def readObject(in: ObjectInputStream): Unit = tryOrIOException {
+ val json = in.readUTF()
+ value = new Schema.Parser().parse(json)
+ }
+
+ private def tryOrIOException[T](block: => T): T = {
+ try {
+ block
+ } catch {
+ case e: IOException =>
+ log.error("Exception encountered", e)
+ throw e
+ case NonFatal(e) =>
+ log.error("Exception encountered", e)
+ throw new IOException(e)
+ }
+ }
+
+ def write(kryo: Kryo, out: Output): Unit = {
+ val dos = new DataOutputStream(out)
+ dos.writeUTF(value.toString())
+ dos.flush()
+ }
+
+ def read(kryo: Kryo, in: Input): Unit = {
+ val dis = new DataInputStream(in)
+ val json = dis.readUTF()
+ value = new Schema.Parser().parse(json)
+ }
+}
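Both the Java path (`writeObject`/`readObject`) and the Kryo path (`write`/`read`) above avoid serializing the Avro `Schema` object directly; instead they write its canonical JSON string and re-parse it on the other side. A quick sketch of that round trip (Avro's `Schema` defines structural equality, so the assertion holds):

```scala
import org.apache.avro.{Schema, SchemaBuilder}

val original: Schema = SchemaBuilder.builder().stringType()

// Same round trip as SerializableSchema: schema -> JSON -> schema.
val json = original.toString()
val restored = new Schema.Parser().parse(json)
assert(restored == original)
```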
http://git-wip-us.apache.org/repos/asf/spark/blob/96030876/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
----------------------------------------------------------------------
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index 4f94d82..6ed6656 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -32,7 +32,6 @@ import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
import org.apache.commons.io.FileUtils
import org.apache.spark.sql._
-import org.apache.spark.sql.avro.SchemaConverters.IncompatibleSchemaException
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.sql.types._
http://git-wip-us.apache.org/repos/asf/spark/blob/96030876/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala
----------------------------------------------------------------------
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala
new file mode 100644
index 0000000..510bcbd
--- /dev/null
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import org.apache.avro.Schema
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance}
+
+class SerializableSchemaSuite extends SparkFunSuite {
+
+ private def testSerialization(serializer: SerializerInstance): Unit = {
+ val avroTypeJson =
+ s"""
+ |{
+ | "type": "string",
+ | "name": "my_string"
+ |}
+ """.stripMargin
+ val avroSchema = new Schema.Parser().parse(avroTypeJson)
+ val serializableSchema = new SerializableSchema(avroSchema)
+ val serialized = serializer.serialize(serializableSchema)
+
+ serializer.deserialize[Any](serialized) match {
+ case c: SerializableSchema =>
+ assert(c.log != null, "log was null")
+ assert(c.value != null, "value was null")
+ assert(c.value == avroSchema)
+ case other => fail(
+ s"Expecting ${classOf[SerializableSchema]}, but got ${other.getClass}.")
+ }
+ }
+
+ test("serialization with JavaSerializer") {
+ testSerialization(new JavaSerializer(new SparkConf()).newInstance())
+ }
+
+ test("serialization with KryoSerializer") {
+ testSerialization(new KryoSerializer(new SparkConf()).newInstance())
+ }
+}