You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2023/06/07 02:41:24 UTC
[spark] branch master updated: [SPARK-43901][SQL] Avro to Support custom decimal type backed by Long
This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d2f72a21677 [SPARK-43901][SQL] Avro to Support custom decimal type backed by Long
d2f72a21677 is described below
commit d2f72a21677748793f0bb329630d72bc91449587
Author: Siying Dong <si...@databricks.com>
AuthorDate: Tue Jun 6 19:41:09 2023 -0700
[SPARK-43901][SQL] Avro to Support custom decimal type backed by Long
### What changes were proposed in this pull request?
Add a custom Avro logical type, "custom-decimal", which can only annotate the physical type long and is converted into Spark's decimal type when read.
### Why are the changes needed?
A user would like to represent currency (money) as a SQL decimal type after loading Avro data. However, there is no good way to represent it in Avro. This custom logical type allows them to do that.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added unit test cases verifying that the new "custom-decimal" type is loaded successfully, as well as cases covering invalid definitions.
Closes #41409 from siying/customdecimal.
Authored-by: Siying Dong <si...@databricks.com>
Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
.../org/apache/spark/sql/avro/CustomDecimal.scala | 79 ++++++++++++++++++
.../apache/spark/sql/avro/AvroDeserializer.scala | 4 +
.../org/apache/spark/sql/avro/AvroFileFormat.scala | 9 +++
.../apache/spark/sql/avro/SchemaConverters.scala | 2 +
.../spark/sql/avro/AvroLogicalTypeSuite.scala | 94 +++++++++++++++++++++-
5 files changed, 187 insertions(+), 1 deletion(-)
diff --git a/connector/avro/src/main/java/org/apache/spark/sql/avro/CustomDecimal.scala b/connector/avro/src/main/java/org/apache/spark/sql/avro/CustomDecimal.scala
new file mode 100644
index 00000000000..d76f40c7635
--- /dev/null
+++ b/connector/avro/src/main/java/org/apache/spark/sql/avro/CustomDecimal.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import org.apache.avro.LogicalType
+import org.apache.avro.Schema
+
+import org.apache.spark.sql.types.DecimalType
+
+/** Constants shared by the `custom-decimal` Avro logical type. */
+object CustomDecimal {
+  /** Value of the `logicalType` attribute that selects this logical type in an Avro schema. */
+  val TYPE_NAME: String = "custom-decimal"
+}
+
+/**
+ * A customized Avro logical type named `custom-decimal`. It is similar to Avro's builtin
+ * `decimal` logical type, but it annotates the `long` physical type: the long value is the
+ * unscaled value, converted to Spark's decimal type using the `precision` and `scale`
+ * properties of the schema, e.g.
+ * {{{
+ *   {"type": "long", "logicalType": "custom-decimal", "precision": 38, "scale": 2}
+ * }}}
+ *
+ * @param schema the Avro schema annotated with this logical type; `precision` and `scale` are
+ *               read from its properties when the instance is constructed
+ * @throws IllegalArgumentException if `precision` or `scale` is missing or is not a JSON int
+ */
+private class CustomDecimal(schema: Schema) extends LogicalType(CustomDecimal.TYPE_NAME) {
+
+  /** Digits to the right of the decimal point, from the schema's `scale` property. */
+  val scale: Int = requiredIntProp(schema, "scale")
+
+  /** Total number of digits, from the schema's `precision` property. */
+  val precision: Int = requiredIntProp(schema, "precision")
+
+  // NOTE(review): not used within this class; presumably null when the schema has no
+  // `className` property -- confirm against Avro's `getProp` contract and any callers.
+  val className: String = schema.getProp("className")
+
+  /**
+   * Checks that this logical type annotates a `long` and that `precision`/`scale` describe
+   * a valid decimal: 0 < precision <= DecimalType.MAX_PRECISION and 0 <= scale <= precision.
+   *
+   * @throws IllegalArgumentException if any of the above does not hold
+   */
+  override def validate(schema: Schema): Unit = {
+    super.validate(schema)
+    if (schema.getType != Schema.Type.LONG) {
+      throw new IllegalArgumentException(
+        s"${CustomDecimal.TYPE_NAME} can only be used with an underlying long type")
+    }
+    if (precision <= 0) {
+      throw new IllegalArgumentException(
+        s"Invalid decimal precision: $precision (must be positive)")
+    } else if (precision > DecimalType.MAX_PRECISION) {
+      throw new IllegalArgumentException(
+        s"cannot store $precision digits (max ${DecimalType.MAX_PRECISION})")
+    }
+    if (scale < 0) {
+      // A scale of 0 is accepted, so the requirement is "non-negative", not "positive".
+      throw new IllegalArgumentException(
+        s"Invalid decimal scale: $scale (must be non-negative)")
+    } else if (scale > precision) {
+      throw new IllegalArgumentException(
+        s"Invalid decimal scale: $scale (greater than precision: $precision)")
+    }
+  }
+
+  // Reads a mandatory integer property. Only a JSON integer is accepted: a quoted value such
+  // as "2" is not an Integer and is rejected like any other wrong type.
+  private def requiredIntProp(s: Schema, name: String): Int = s.getObjectProp(name) match {
+    case null =>
+      throw new IllegalArgumentException(s"Invalid ${CustomDecimal.TYPE_NAME}: missing $name")
+    case i: Integer =>
+      i
+    case _ =>
+      throw new IllegalArgumentException(s"Expected int ${CustomDecimal.TYPE_NAME}:$name")
+  }
+}
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index 78b1f01e2ef..2a5f9598518 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -137,6 +137,7 @@ private[sql] class AvroDeserializer(
updater.setBoolean(ordinal, value.asInstanceOf[Boolean])
case _ => throw new IncompatibleSchemaException(incompatibleMsg)
}
+
case INT =>
(logicalDataType, catalystType) match {
case (IntegerType, IntegerType) => (updater, ordinal, value) =>
@@ -206,6 +207,9 @@ private[sql] class AvroDeserializer(
logicalDataType.catalogString, catalystType.catalogString, confKey.key)
case (_: DayTimeIntervalType, DateType) => (updater, ordinal, value) =>
updater.setInt(ordinal, (value.asInstanceOf[Long] / MILLIS_PER_DAY).toInt)
+ case (_, dt: DecimalType) => (updater, ordinal, value) =>
+ val d = avroType.getLogicalType.asInstanceOf[CustomDecimal]
+ updater.setDecimal(ordinal, Decimal(value.asInstanceOf[Long], d.precision, d.scale))
case _ if !preventReadingIncorrectType => (updater, ordinal, value) =>
updater.setLong(ordinal, value.asInstanceOf[Long])
case _ => throw new IncompatibleSchemaException(incompatibleMsg)
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
index 3e16e121081..53562a3afdb 100755
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
@@ -21,6 +21,8 @@ import java.io._
import scala.util.control.NonFatal
+import org.apache.avro.{LogicalTypes, Schema}
+import org.apache.avro.LogicalType
import org.apache.avro.file.DataFileReader
import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
import org.apache.avro.mapred.FsInput
@@ -168,4 +170,11 @@ private[sql] class AvroFileFormat extends FileFormat
private[avro] object AvroFileFormat {
val IgnoreFilesWithoutExtensionProperty = "avro.mapred.ignore.inputs.without.extension"
+
+  // Registers the long-backed `custom-decimal` logical type (see CustomDecimal) with Avro's
+  // LogicalTypes registry. Being part of this object's body, it runs when the object is first
+  // initialized.
+  // NOTE(review): code paths that never initialize AvroFileFormat (e.g. from_avro going
+  // straight to SchemaConverters) presumably would not have the type registered -- confirm.
+  LogicalTypes.register(
+    CustomDecimal.TYPE_NAME,
+    (avroSchema: Schema) => new CustomDecimal(avroSchema))
}
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
index e2e2739e7cf..6f21639e28d 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
@@ -89,6 +89,8 @@ object SchemaConverters {
case DOUBLE => SchemaType(DoubleType, nullable = false)
case FLOAT => SchemaType(FloatType, nullable = false)
case LONG => avroSchema.getLogicalType match {
+ case d: CustomDecimal =>
+ SchemaType(DecimalType(d.precision, d.scale), nullable = false)
case _: TimestampMillis | _: TimestampMicros => SchemaType(TimestampType, nullable = false)
case _: LocalTimestampMillis | _: LocalTimestampMicros =>
SchemaType(TimestampNTZType, nullable = false)
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
index c0022c62735..202b09242a0 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{StructField, StructType, TimestampNTZType, TimestampType}
+import org.apache.spark.sql.types.{DecimalType, LongType, StructField, StructType, TimestampNTZType, TimestampType}
abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
import testImplicits._
@@ -446,6 +446,98 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
)
}
}
+
+  test("SPARK-43901: LogicalType: Custom Decimal for Long Type") {
+    val schema =
+      new Schema.Parser().parse("""{
+        "namespace": "logical",
+        "type": "record",
+        "name": "test",
+        "fields": [
+          {
+            "name": "field1",
+            "type": {"type": "long", "logicalType": "custom-decimal", "scale": 2, "precision": 38}
+          },
+          {
+            "name": "field2",
+            "type": {"type": "long", "logicalType": "custom-decimal", "scale": 9, "precision": 33}
+          },
+          {
+            "name": "field3",
+            "type": "long"
+          }]
+      }""")
+
+    withTempDir { dir =>
+      // Write inside the temp dir rather than next to it (`s"$dir.avro"` is a sibling path),
+      // so the file is cleaned up together with the directory.
+      val avroFile = new File(dir, "custom-decimal.avro")
+      val dataFileWriter =
+        new DataFileWriter[GenericRecord](new GenericDatumWriter[GenericRecord](schema))
+      try {
+        dataFileWriter.create(schema, avroFile)
+        val avroRec = new GenericData.Record(schema)
+        avroRec.put("field1", 123456789L)
+        avroRec.put("field2", 123456789L)
+        avroRec.put("field3", 123456789L)
+        dataFileWriter.append(avroRec)
+        dataFileWriter.flush()
+      } finally {
+        // Close in `finally` so the writer is released even if writing fails.
+        dataFileWriter.close()
+      }
+
+      val df = spark.read.format("avro").load(avroFile.getCanonicalPath)
+      // Each custom-decimal field maps to its declared decimal type; a plain long stays long.
+      assertResult(Seq(DecimalType(38, 2), DecimalType(33, 9), LongType)) {
+        df.schema.map(_.dataType)
+      }
+      val firstRow = df.head()
+      // The stored long is the unscaled value of the decimal.
+      assertResult(java.math.BigDecimal.valueOf(123456789L, 2))(firstRow.getAs("field1"))
+      assertResult(java.math.BigDecimal.valueOf(123456789L, 9))(firstRow.getAs("field2"))
+      assertResult(123456789L)(firstRow.getAs("field3"))
+    }
+  }
+
+  test("SPARK-43901: LogicalType: Decimal for Long Type Exception Cases") {
+    // Avro appears to catch any exception thrown while creating a customized logical type and
+    // leave the schema's logical type as null, so an invalid definition cannot be told apart
+    // from a long with no logical type at all. Each invalid definition below must therefore
+    // read back as a plain LongType column.
+    Seq(
+      """ "type": "long", "logicalType": "custom-decimal", "scale": 2, "precision": 50 """,
+      """ "type": "long", "logicalType": "custom-decimal", "scale": -2, "precision": 30 """,
+      """ "type": "long", "logicalType": "custom-decimal", "scale": 2, "precision": -30 """,
+      """ "type": "long", "logicalType": "custom-decimal", "scale": 30, "precision": 20 """,
+      """ "type": "long", "logicalType": "custom-decimal", "scale": "2", "precision": 30 """,
+      """ "type": "long", "logicalType": "custom-decimal", "scale": "xx", "precision": 30 """,
+      """ "type": "long", "logicalType": "custom-decimal", "scale": 2, "precision": "30" """,
+      """ "type": "long", "logicalType": "custom-decimal", "scale": 2, "precision": "xx" """,
+      """ "type": "long", "logicalType": "custom-decimal", "precision": 30 """,
+      """ "type": "long", "logicalType": "custom-decimal", "scale": 2 """,
+      """ "type": "long", "logicalType": "custom-decimal" """
+    ).foreach { d =>
+      val schema =
+        new Schema.Parser().parse(
+          s"""{
+            "namespace": "logical",
+            "type": "record",
+            "name": "test",
+            "fields": [
+              {
+                "name": "field",
+                "type": {$d}
+              }]
+          }""")
+
+      withTempDir { dir =>
+        // Write inside the temp dir rather than next to it (`s"$dir.avro"` is a sibling
+        // path), so the file is cleaned up together with the directory.
+        val avroFile = new File(dir, "custom-decimal.avro")
+        val dataFileWriter =
+          new DataFileWriter[GenericRecord](new GenericDatumWriter[GenericRecord](schema))
+        try {
+          dataFileWriter.create(schema, avroFile)
+          val avroRec = new GenericData.Record(schema)
+          avroRec.put("field", 123456789L)
+          dataFileWriter.append(avroRec)
+          dataFileWriter.flush()
+        } finally {
+          // Close in `finally` so the writer is released even if writing fails.
+          dataFileWriter.close()
+        }
+
+        val df = spark.read.format("avro").load(avroFile.getCanonicalPath)
+        // Name the offending definition in the failure message.
+        withClue(s"custom-decimal definition: $d") {
+          assertResult(LongType)(df.schema.head.dataType)
+          assertResult(123456789L)(df.head().getAs("field"))
+        }
+      }
+    }
+  }
}
class AvroV1LogicalTypeSuite extends AvroLogicalTypeSuite {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org