You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2021/01/26 23:34:25 UTC

[spark] branch branch-2.4 updated: [SPARK-34212][SQL] Fix incorrect decimal reading from Parquet files

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 6519a7e  [SPARK-34212][SQL] Fix incorrect decimal reading from Parquet files
6519a7e is described below

commit 6519a7e060f4be134dc51a3792ad26dbbda0a050
Author: Dongjoon Hyun <dh...@apple.com>
AuthorDate: Tue Jan 26 15:13:39 2021 -0800

    [SPARK-34212][SQL] Fix incorrect decimal reading from Parquet files
    
    This PR aims to fix the correctness issues when reading decimal values from Parquet files.
    - For **MR** code path, `ParquetRowConverter` can read Parquet's decimal values with the original precision and scale written in the corresponding footer.
    - For **Vectorized** code path, `VectorizedColumnReader` throws `SchemaColumnConvertNotSupportedException` when the requested decimal precision or scale does not match the file's decimal metadata.
    
    Currently, Spark returns incorrect results when the Parquet file's decimal precision and scale are different from the schema Spark reads with. This happens when there are multiple files with different decimal schemas, or when the Hive Metastore has a new schema.
    
    **BEFORE (Simplified example for correctness)**
    
    ```scala
    scala> sql("SELECT 1.0 a").write.parquet("/tmp/decimal")
    scala> spark.read.schema("a DECIMAL(3,2)").parquet("/tmp/decimal").show
    +----+
    |   a|
    +----+
    |0.10|
    +----+
    ```
    
    This works correctly in other data sources such as `ORC/JSON/CSV`, as shown below.
    ```scala
    scala> sql("SELECT 1.0 a").write.orc("/tmp/decimal_orc")
    scala> spark.read.schema("a DECIMAL(3,2)").orc("/tmp/decimal_orc").show
    +----+
    |   a|
    +----+
    |1.00|
    +----+
    ```
    
    **AFTER**
    1. **Vectorized** path: Instead of returning an incorrect result, Spark raises an explicit exception.
    ```scala
    scala> spark.read.schema("a DECIMAL(3,2)").parquet("/tmp/decimal").show
    java.lang.UnsupportedOperationException: Schema evolution not supported.
    ```
    
    2. **MR** path (complex schema or explicit configuration): Spark returns correct results.
    ```scala
    scala> spark.read.schema("a DECIMAL(3,2), b DECIMAL(18, 3), c MAP<INT,INT>").parquet("/tmp/decimal").show
    +----+-------+--------+
    |   a|      b|       c|
    +----+-------+--------+
    |1.00|100.000|{1 -> 2}|
    +----+-------+--------+
    
    scala> spark.read.schema("a DECIMAL(3,2), b DECIMAL(18, 3), c MAP<INT,INT>").parquet("/tmp/decimal").printSchema
    root
     |-- a: decimal(3,2) (nullable = true)
     |-- b: decimal(18,3) (nullable = true)
     |-- c: map (nullable = true)
     |    |-- key: integer
     |    |-- value: integer (valueContainsNull = true)
    ```
    
    Yes, this is a user-facing change: it fixes the correctness issue.
    
    Tested with the newly added test case.
    
    Closes #31319 from dongjoon-hyun/SPARK-34212.
    
    Lead-authored-by: Dongjoon Hyun <dh...@apple.com>
    Co-authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
    (cherry picked from commit dbf051c50a17d644ecc1823e96eede4a5a6437fd)
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../parquet/VectorizedColumnReader.java            | 45 ++++++++++++++-----
 .../datasources/parquet/ParquetRowConverter.scala  | 22 +++++++--
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 52 ++++++++++++++++++++++
 3 files changed, 105 insertions(+), 14 deletions(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index ba26b57..4739089 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -30,12 +30,14 @@ import org.apache.parquet.column.Encoding;
 import org.apache.parquet.column.page.*;
 import org.apache.parquet.column.values.ValuesReader;
 import org.apache.parquet.io.api.Binary;
+import org.apache.parquet.schema.DecimalMetadata;
 import org.apache.parquet.schema.OriginalType;
 import org.apache.parquet.schema.PrimitiveType;
 
 import org.apache.spark.sql.catalyst.util.DateTimeUtils;
 import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException;
 import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
+import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.DecimalType;
 
@@ -101,6 +103,27 @@ public class VectorizedColumnReader {
   private final TimeZone convertTz;
   private static final TimeZone UTC = DateTimeUtils.TimeZoneUTC();
 
+  private boolean isDecimalTypeMatched(DataType dt) {
+    DecimalType d = (DecimalType) dt;
+    DecimalMetadata dm = descriptor.getPrimitiveType().getDecimalMetadata();
+    return dm != null && dm.getPrecision() == d.precision() && dm.getScale() == d.scale();
+  }
+
+  private boolean canReadAsIntDecimal(DataType dt) {
+    if (!DecimalType.is32BitDecimalType(dt)) return false;
+    return isDecimalTypeMatched(dt);
+  }
+
+  private boolean canReadAsLongDecimal(DataType dt) {
+    if (!DecimalType.is64BitDecimalType(dt)) return false;
+    return isDecimalTypeMatched(dt);
+  }
+
+  private boolean canReadAsBinaryDecimal(DataType dt) {
+    if (!DecimalType.isByteArrayDecimalType(dt)) return false;
+    return isDecimalTypeMatched(dt);
+  }
+
   public VectorizedColumnReader(
       ColumnDescriptor descriptor,
       OriginalType originalType,
@@ -261,7 +284,7 @@ public class VectorizedColumnReader {
     switch (descriptor.getPrimitiveType().getPrimitiveTypeName()) {
       case INT32:
         if (column.dataType() == DataTypes.IntegerType ||
-            DecimalType.is32BitDecimalType(column.dataType())) {
+            canReadAsIntDecimal(column.dataType())) {
           for (int i = rowId; i < rowId + num; ++i) {
             if (!column.isNullAt(i)) {
               column.putInt(i, dictionary.decodeToInt(dictionaryIds.getDictId(i)));
@@ -286,7 +309,7 @@ public class VectorizedColumnReader {
 
       case INT64:
         if (column.dataType() == DataTypes.LongType ||
-            DecimalType.is64BitDecimalType(column.dataType()) ||
+            canReadAsLongDecimal(column.dataType()) ||
             originalType == OriginalType.TIMESTAMP_MICROS) {
           for (int i = rowId; i < rowId + num; ++i) {
             if (!column.isNullAt(i)) {
@@ -357,21 +380,21 @@ public class VectorizedColumnReader {
         break;
       case FIXED_LEN_BYTE_ARRAY:
         // DecimalType written in the legacy mode
-        if (DecimalType.is32BitDecimalType(column.dataType())) {
+        if (canReadAsIntDecimal(column.dataType())) {
           for (int i = rowId; i < rowId + num; ++i) {
             if (!column.isNullAt(i)) {
               Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
               column.putInt(i, (int) ParquetRowConverter.binaryToUnscaledLong(v));
             }
           }
-        } else if (DecimalType.is64BitDecimalType(column.dataType())) {
+        } else if (canReadAsLongDecimal(column.dataType())) {
           for (int i = rowId; i < rowId + num; ++i) {
             if (!column.isNullAt(i)) {
               Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
               column.putLong(i, ParquetRowConverter.binaryToUnscaledLong(v));
             }
           }
-        } else if (DecimalType.isByteArrayDecimalType(column.dataType())) {
+        } else if (canReadAsBinaryDecimal(column.dataType())) {
           for (int i = rowId; i < rowId + num; ++i) {
             if (!column.isNullAt(i)) {
               Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
@@ -407,7 +430,7 @@ public class VectorizedColumnReader {
     // This is where we implement support for the valid type conversions.
     // TODO: implement remaining type conversions
     if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType ||
-        DecimalType.is32BitDecimalType(column.dataType())) {
+        canReadAsIntDecimal(column.dataType())) {
       defColumn.readIntegers(
           num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
     } else if (column.dataType() == DataTypes.ByteType) {
@@ -424,7 +447,7 @@ public class VectorizedColumnReader {
   private void readLongBatch(int rowId, int num, WritableColumnVector column) throws IOException {
     // This is where we implement support for the valid type conversions.
     if (column.dataType() == DataTypes.LongType ||
-        DecimalType.is64BitDecimalType(column.dataType()) ||
+        canReadAsLongDecimal(column.dataType()) ||
         originalType == OriginalType.TIMESTAMP_MICROS) {
       defColumn.readLongs(
         num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
@@ -468,7 +491,7 @@ public class VectorizedColumnReader {
     // TODO: implement remaining type conversions
     VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
     if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType
-            || DecimalType.isByteArrayDecimalType(column.dataType())) {
+            || canReadAsBinaryDecimal(column.dataType())) {
       defColumn.readBinarys(num, column, rowId, maxDefLevel, data);
     } else if (column.dataType() == DataTypes.TimestampType) {
       if (!shouldConvertTimestamps()) {
@@ -506,7 +529,7 @@ public class VectorizedColumnReader {
     VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
     // This is where we implement support for the valid type conversions.
     // TODO: implement remaining type conversions
-    if (DecimalType.is32BitDecimalType(column.dataType())) {
+    if (canReadAsIntDecimal(column.dataType())) {
       for (int i = 0; i < num; i++) {
         if (defColumn.readInteger() == maxDefLevel) {
           column.putInt(rowId + i,
@@ -515,7 +538,7 @@ public class VectorizedColumnReader {
           column.putNull(rowId + i);
         }
       }
-    } else if (DecimalType.is64BitDecimalType(column.dataType())) {
+    } else if (canReadAsLongDecimal(column.dataType())) {
       for (int i = 0; i < num; i++) {
         if (defColumn.readInteger() == maxDefLevel) {
           column.putLong(rowId + i,
@@ -524,7 +547,7 @@ public class VectorizedColumnReader {
           column.putNull(rowId + i);
         }
       }
-    } else if (DecimalType.isByteArrayDecimalType(column.dataType())) {
+    } else if (canReadAsBinaryDecimal(column.dataType())) {
       for (int i = 0; i < num; i++) {
         if (defColumn.readInteger() == maxDefLevel) {
           column.putByteArray(rowId + i, data.readBinary(arrayLen).getBytes());
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index 1199725..0d22fe5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -210,6 +210,19 @@ private[parquet] class ParquetRowConverter(
   }
 
   /**
+   * Get a precision and a scale to interpret parquet decimal values.
+   * 1. If there is a decimal metadata, we read decimal values with the given precision and scale.
+   * 2. If there is no metadata, we read decimal values with scale `0` because they are
+   *    plain integers when written into INT32/INT64/BINARY/FIXED_LEN_BYTE_ARRAY types.
+   */
+  private def getPrecisionAndScale(parquetType: Type, t: DecimalType): (Int, Int) = {
+    val metadata = parquetType.asPrimitiveType().getDecimalMetadata
+    val precision = if (metadata == null) t.precision else metadata.getPrecision()
+    val scale = if (metadata == null) 0 else metadata.getScale()
+    (precision, scale)
+  }
+
+  /**
    * Creates a converter for the given Parquet type `parquetType` and Spark SQL data type
    * `catalystType`. Converted values are handled by `updater`.
    */
@@ -236,17 +249,20 @@ private[parquet] class ParquetRowConverter(
 
       // For INT32 backed decimals
       case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 =>
-        new ParquetIntDictionaryAwareDecimalConverter(t.precision, t.scale, updater)
+        val (precision, scale) = getPrecisionAndScale(parquetType, t)
+        new ParquetIntDictionaryAwareDecimalConverter(precision, scale, updater)
 
       // For INT64 backed decimals
       case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 =>
-        new ParquetLongDictionaryAwareDecimalConverter(t.precision, t.scale, updater)
+        val (precision, scale) = getPrecisionAndScale(parquetType, t)
+        new ParquetLongDictionaryAwareDecimalConverter(precision, scale, updater)
 
       // For BINARY and FIXED_LEN_BYTE_ARRAY backed decimals
       case t: DecimalType
         if parquetType.asPrimitiveType().getPrimitiveTypeName == FIXED_LEN_BYTE_ARRAY ||
            parquetType.asPrimitiveType().getPrimitiveTypeName == BINARY =>
-        new ParquetBinaryDictionaryAwareDecimalConverter(t.precision, t.scale, updater)
+        val (precision, scale) = getPrecisionAndScale(parquetType, t)
+        new ParquetBinaryDictionaryAwareDecimalConverter(precision, scale, updater)
 
       case t: DecimalType =>
         throw new RuntimeException(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index ab2a1c9..1af50bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.datasources.FilePartition
+import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -3140,6 +3141,57 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       }
     })
   }
+
+  test("SPARK-34212 Parquet should read decimals correctly") {
+    // a is int-decimal (4 bytes), b is long-decimal (8 bytes), c is binary-decimal (16 bytes)
+    val df = sql("SELECT 1.0 a, CAST(1.23 AS DECIMAL(17, 2)) b, CAST(1.23 AS DECIMAL(36, 2)) c")
+
+    withTempPath { path =>
+      df.write.parquet(path.toString)
+
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+        val schema1 = "a DECIMAL(3, 2), b DECIMAL(18, 3), c DECIMAL(37, 3)"
+        checkAnswer(spark.read.schema(schema1).parquet(path.toString), df)
+        val schema2 = "a DECIMAL(3, 0), b DECIMAL(18, 1), c DECIMAL(37, 1)"
+        checkAnswer(spark.read.schema(schema2).parquet(path.toString), Row(1, 1.2, 1.2))
+      }
+
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
+        val e1 = intercept[SparkException] {
+          spark.read.schema("a DECIMAL(3, 2)").parquet(path.toString).collect()
+        }.getCause.getCause
+        assert(e1.isInstanceOf[SchemaColumnConvertNotSupportedException])
+
+        val e2 = intercept[SparkException] {
+          spark.read.schema("b DECIMAL(18, 1)").parquet(path.toString).collect()
+        }.getCause.getCause
+        assert(e2.isInstanceOf[SchemaColumnConvertNotSupportedException])
+
+        val e3 = intercept[SparkException] {
+          spark.read.schema("c DECIMAL(37, 1)").parquet(path.toString).collect()
+        }.getCause.getCause
+        assert(e3.isInstanceOf[SchemaColumnConvertNotSupportedException])
+      }
+    }
+
+    withTempPath { path =>
+      val df2 = sql(s"SELECT 1 a, ${Int.MaxValue + 1L} b")
+      df2.write.parquet(path.toString)
+
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+        val schema = "a DECIMAL(3, 2), b DECIMAL(17, 2)"
+        checkAnswer(spark.read.schema(schema).parquet(path.toString),
+          Row(BigDecimal(100, 2), BigDecimal((Int.MaxValue + 1L) * 100, 2)))
+      }
+
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
+        val e = intercept[SparkException] {
+          spark.read.schema("a DECIMAL(3, 2)").parquet(path.toString).collect()
+        }.getCause.getCause
+        assert(e.isInstanceOf[SchemaColumnConvertNotSupportedException])
+      }
+    }
+  }
 }
 
 case class Foo(bar: Option[String])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org