Posted to commits@spark.apache.org by do...@apache.org on 2023/06/07 02:41:24 UTC

[spark] branch master updated: [SPARK-43901][SQL] Avro to Support custom decimal type backed by Long

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d2f72a21677 [SPARK-43901][SQL] Avro to Support custom decimal type backed by Long
d2f72a21677 is described below

commit d2f72a21677748793f0bb329630d72bc91449587
Author: Siying Dong <si...@databricks.com>
AuthorDate: Tue Jun 6 19:41:09 2023 -0700

    [SPARK-43901][SQL] Avro to Support custom decimal type backed by Long
    
    ### What changes were proposed in this pull request?
    Add a logical type "custom-decimal" to Avro, which can only be backed by the physical type long and will be converted into Spark's decimal type, with user-provided precision and scale.
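    
    For illustration, a field using the new logical type might be declared as below
    (a sketch: the field name is made up, while the "logicalType", "precision" and
    "scale" properties match the implementation in this patch):
    
    ```json
    {
      "name": "amount",
      "type": {"type": "long", "logicalType": "custom-decimal", "precision": 38, "scale": 2}
    }
    ```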
    
    ### Why are the changes needed?
    A user would like to represent currency (money) as a decimal SQL type after loading Avro data. However, Avro has no good built-in way to represent a decimal backed by a long. This custom type allows them to do that.
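    
    For example (a sketch, assuming a SparkSession `spark` and an Avro file at a
    hypothetical path, written with a schema like the one above), a stored long
    12345 with precision 38 and scale 2 reads back as the decimal 123.45:
    
    ```scala
    val df = spark.read.format("avro").load("/tmp/currency.avro")
    df.printSchema()            // amount: decimal(38,2)
    df.select("amount").show()  // shows 123.45
    ```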
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Added several unit test cases verifying that the new "custom-decimal" type is loaded successfully, as well as exception cases.
    
    Closes #41409 from siying/customdecimal.
    
    Authored-by: Siying Dong <si...@databricks.com>
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
 .../org/apache/spark/sql/avro/CustomDecimal.scala  | 79 ++++++++++++++++++
 .../apache/spark/sql/avro/AvroDeserializer.scala   |  4 +
 .../org/apache/spark/sql/avro/AvroFileFormat.scala |  9 +++
 .../apache/spark/sql/avro/SchemaConverters.scala   |  2 +
 .../spark/sql/avro/AvroLogicalTypeSuite.scala      | 94 +++++++++++++++++++++-
 5 files changed, 187 insertions(+), 1 deletion(-)

diff --git a/connector/avro/src/main/java/org/apache/spark/sql/avro/CustomDecimal.scala b/connector/avro/src/main/java/org/apache/spark/sql/avro/CustomDecimal.scala
new file mode 100644
index 00000000000..d76f40c7635
--- /dev/null
+++ b/connector/avro/src/main/java/org/apache/spark/sql/avro/CustomDecimal.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import org.apache.avro.LogicalType
+import org.apache.avro.Schema
+
+import org.apache.spark.sql.types.DecimalType
+
+object CustomDecimal {
+  val TYPE_NAME = "custom-decimal"
+}
+
+// A customized logical type that is registered with Avro. It is similar to Avro's built-in
+// decimal type, but is meant to be registered for the long physical type. It indicates that the
+// long should be converted to Spark's decimal type, with the provided precision and scale.
+private class CustomDecimal(schema: Schema) extends LogicalType(CustomDecimal.TYPE_NAME) {
+  val scale: Int = {
+    val obj = schema.getObjectProp("scale")
+    obj match {
+      case null =>
+        throw new IllegalArgumentException(s"Invalid ${CustomDecimal.TYPE_NAME}: missing scale")
+      case i: Integer =>
+        i
+      case _ =>
+        throw new IllegalArgumentException(s"Expected an int for ${CustomDecimal.TYPE_NAME} scale")
+    }
+  }
+  val precision: Int = {
+    val obj = schema.getObjectProp("precision")
+    obj match {
+      case null =>
+        throw new IllegalArgumentException(
+          s"Invalid ${CustomDecimal.TYPE_NAME}: missing precision")
+      case i: Integer =>
+        i
+      case _ =>
+        throw new IllegalArgumentException(s"Expected an int for ${CustomDecimal.TYPE_NAME} precision")
+    }
+  }
+  val className: String = schema.getProp("className")
+
+  override def validate(schema: Schema): Unit = {
+    super.validate(schema)
+    if (schema.getType != Schema.Type.LONG) {
+      throw new IllegalArgumentException(
+        s"${CustomDecimal.TYPE_NAME} can only be used with an underlying long type")
+    }
+    if (precision <= 0) {
+      throw new IllegalArgumentException(s"Invalid decimal precision: $precision" +
+        " (must be positive)")
+    } else if (precision > DecimalType.MAX_PRECISION) {
+      throw new IllegalArgumentException(
+        s"cannot store $precision digits (max ${DecimalType.MAX_PRECISION})")
+    }
+    if (scale < 0) {
+      throw new IllegalArgumentException(s"Invalid decimal scale: $scale" +
+        " (must be non-negative)")
+    } else if (scale > precision) {
+      throw new IllegalArgumentException(s"Invalid decimal scale: $scale (greater than " +
+        s"precision: $precision)")
+    }
+  }
+}
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index 78b1f01e2ef..2a5f9598518 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -137,6 +137,7 @@ private[sql] class AvroDeserializer(
             updater.setBoolean(ordinal, value.asInstanceOf[Boolean])
           case _ => throw new IncompatibleSchemaException(incompatibleMsg)
         }
+
       case INT =>
         (logicalDataType, catalystType) match {
           case (IntegerType, IntegerType) => (updater, ordinal, value) =>
@@ -206,6 +207,9 @@ private[sql] class AvroDeserializer(
               logicalDataType.catalogString, catalystType.catalogString, confKey.key)
           case (_: DayTimeIntervalType, DateType) => (updater, ordinal, value) =>
             updater.setInt(ordinal, (value.asInstanceOf[Long] / MILLIS_PER_DAY).toInt)
+          case (_, _: DecimalType) => (updater, ordinal, value) =>
+            val d = avroType.getLogicalType.asInstanceOf[CustomDecimal]
+            updater.setDecimal(ordinal, Decimal(value.asInstanceOf[Long], d.precision, d.scale))
           case _ if !preventReadingIncorrectType => (updater, ordinal, value) =>
             updater.setLong(ordinal, value.asInstanceOf[Long])
           case _ => throw new IncompatibleSchemaException(incompatibleMsg)
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
index 3e16e121081..53562a3afdb 100755
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
@@ -21,6 +21,8 @@ import java.io._
 
 import scala.util.control.NonFatal
 
+import org.apache.avro.{LogicalTypes, Schema}
+import org.apache.avro.LogicalType
 import org.apache.avro.file.DataFileReader
 import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
 import org.apache.avro.mapred.FsInput
@@ -168,4 +170,11 @@ private[sql] class AvroFileFormat extends FileFormat
 
 private[avro] object AvroFileFormat {
   val IgnoreFilesWithoutExtensionProperty = "avro.mapred.ignore.inputs.without.extension"
+
+  // Register the customized decimal type backed by long.
+  LogicalTypes.register(CustomDecimal.TYPE_NAME, new LogicalTypes.LogicalTypeFactory {
+    override def fromSchema(schema: Schema): LogicalType = {
+      new CustomDecimal(schema)
+    }
+  })
 }
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
index e2e2739e7cf..6f21639e28d 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
@@ -89,6 +89,8 @@ object SchemaConverters {
       case DOUBLE => SchemaType(DoubleType, nullable = false)
       case FLOAT => SchemaType(FloatType, nullable = false)
       case LONG => avroSchema.getLogicalType match {
+        case d: CustomDecimal =>
+          SchemaType(DecimalType(d.precision, d.scale), nullable = false)
         case _: TimestampMillis | _: TimestampMicros => SchemaType(TimestampType, nullable = false)
         case _: LocalTimestampMillis | _: LocalTimestampMicros =>
           SchemaType(TimestampNTZType, nullable = false)
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
index c0022c62735..202b09242a0 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{StructField, StructType, TimestampNTZType, TimestampType}
+import org.apache.spark.sql.types.{DecimalType, LongType, StructField, StructType, TimestampNTZType, TimestampType}
 
 abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
@@ -446,6 +446,98 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
       )
     }
   }
+
+  test("SPARK-43901: LogicalType: Custom Decimal for Long Type") {
+    val schema =
+      new Schema.Parser().parse("""{
+        "namespace": "logical",
+        "type": "record",
+        "name": "test",
+        "fields": [
+         {
+           "name": "field1",
+           "type": {"type": "long", "logicalType": "custom-decimal", "scale": 2, "precision": 38}
+         },
+         {
+           "name": "field2",
+           "type": {"type": "long", "logicalType": "custom-decimal", "scale": 9, "precision": 33}
+         },
+         {
+           "name": "field3",
+           "type": "long"
+         }]
+        }""")
+
+    withTempDir { dir =>
+      val datumWriter = new GenericDatumWriter[GenericRecord](schema)
+      val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter)
+      dataFileWriter.create(schema, new File(s"$dir.avro"))
+      val avroRec = new GenericData.Record(schema)
+      avroRec.put("field1", 123456789L)
+      avroRec.put("field2", 123456789L)
+      avroRec.put("field3", 123456789L)
+      dataFileWriter.append(avroRec)
+      dataFileWriter.flush()
+      dataFileWriter.close()
+      val df = spark
+        .read
+        .format("avro")
+        .load(s"$dir.avro")
+      assertResult(DecimalType(38, 2))(df.schema.head.dataType)
+      val firstRow = df.take(1)(0)
+      assertResult(java.math.BigDecimal.valueOf(123456789L, 2))(firstRow.getAs("field1"))
+      assertResult(java.math.BigDecimal.valueOf(123456789L, 9))(firstRow.getAs("field2"))
+      assertResult(123456789L)(firstRow.getAs("field3"))
+    }
+  }
+
+  test("SPARK-43901: LogicalType: Decimal for Long Type Exception Cases") {
+    // Avro appears to catch all exceptions when creating a customized logical type and turns the
+    // logical type into null, so it can't be distinguished from the case with no logical type given.
+    Seq(
+      """ "type": "long", "logicalType": "custom-decimal", "scale": 2, "precision": 50 """,
+      """ "type": "long", "logicalType": "custom-decimal", "scale": -2, "precision": 30 """,
+      """ "type": "long", "logicalType": "custom-decimal", "scale": 2, "precision": -30 """,
+      """ "type": "long", "logicalType": "custom-decimal", "scale": 30, "precision": 20 """,
+      """ "type": "long", "logicalType": "custom-decimal", "scale": "2", "precision": 30 """,
+      """ "type": "long", "logicalType": "custom-decimal", "scale": "xx", "precision": 30 """,
+      """ "type": "long", "logicalType": "custom-decimal", "scale": 2, "precision": "30" """,
+      """ "type": "long", "logicalType": "custom-decimal", "scale": 2, "precision": "xx" """,
+      """ "type": "long", "logicalType": "custom-decimal", "precision": 30 """,
+      """ "type": "long", "logicalType": "custom-decimal", "scale": 2 """,
+      """ "type": "long", "logicalType": "custom-decimal" """
+    ).foreach { d =>
+      val schema =
+        new Schema.Parser().parse(
+          s"""{
+            "namespace": "logical",
+            "type": "record",
+            "name": "test",
+            "fields": [
+            {
+              "name": "field",
+              "type": {$d}
+             }]
+          }""")
+
+      withTempDir { dir =>
+        val datumWriter = new GenericDatumWriter[GenericRecord](schema)
+        val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter)
+        dataFileWriter.create(schema, new File(s"$dir.avro"))
+        val avroRec = new GenericData.Record(schema)
+        avroRec.put("field", 123456789L)
+        dataFileWriter.append(avroRec)
+        dataFileWriter.flush()
+        dataFileWriter.close()
+        val df = spark.read
+          .format("avro")
+          .load(s"$dir.avro")
+        assertResult(LongType)(df.schema.head.dataType)
+        val firstRow = df.take(1)(0)
+        assertResult(123456789L)(firstRow.getAs("field"))
+      }
+    }
+  }
 }
 
 class AvroV1LogicalTypeSuite extends AvroLogicalTypeSuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org