Posted to commits@spark.apache.org by do...@apache.org on 2021/01/26 23:21:08 UTC

[spark] branch branch-3.0 updated: [SPARK-34212][SQL] Fix incorrect decimal reading from Parquet files

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 1deb1f8  [SPARK-34212][SQL] Fix incorrect decimal reading from Parquet files
1deb1f8 is described below

commit 1deb1f8c508579964765df94c1183bda6dcad5de
Author: Dongjoon Hyun <dh...@apple.com>
AuthorDate: Tue Jan 26 15:13:39 2021 -0800

    [SPARK-34212][SQL] Fix incorrect decimal reading from Parquet files
    
    This PR aims to fix the correctness issues that occur when reading decimal values from Parquet files.
    - For **MR** code path, `ParquetRowConverter` can read Parquet's decimal values with the original precision and scale written in the corresponding footer.
    - For **Vectorized** code path, `VectorizedColumnReader` throws `SchemaColumnConvertNotSupportedException`.
    
    Currently, Spark returns incorrect results when the decimal precision and scale in a Parquet file differ from those in Spark's schema. This happens when there are multiple files with different decimal schemas, or when the Hive Metastore has a newer schema.
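
    The multi-file case can be reproduced with a sketch like the following (the directory layout and paths are hypothetical; two writes with different decimal precision/scale end up under the same table location). Before this fix, the values coming from the `DECIMAL(2,1)` file are misinterpreted under the requested `DECIMAL(3,2)` schema.
    ```scala
    scala> // Hypothetical paths: two files with different decimal schemas read together
    scala> sql("SELECT CAST(1.0 AS DECIMAL(2,1)) a").write.parquet("/tmp/decimal_mixed/p1")
    scala> sql("SELECT CAST(1.00 AS DECIMAL(3,2)) a").write.parquet("/tmp/decimal_mixed/p2")
    scala> spark.read.schema("a DECIMAL(3,2)").parquet("/tmp/decimal_mixed/p1", "/tmp/decimal_mixed/p2").show
    ```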
    
    **BEFORE (simplified example of the correctness issue)**
    
    ```scala
    scala> sql("SELECT 1.0 a").write.parquet("/tmp/decimal")
    scala> spark.read.schema("a DECIMAL(3,2)").parquet("/tmp/decimal").show
    +----+
    |   a|
    +----+
    |0.10|
    +----+
    ```
    
    Other data sources (`ORC`/`JSON`/`CSV`) handle this correctly, as the following ORC example shows.
    ```scala
    scala> sql("SELECT 1.0 a").write.orc("/tmp/decimal_orc")
    scala> spark.read.schema("a DECIMAL(3,2)").orc("/tmp/decimal_orc").show
    +----+
    |   a|
    +----+
    |1.00|
    +----+
    ```
    
    **AFTER**
    1. **Vectorized** path: instead of returning an incorrect result, Spark now raises an explicit exception (a workaround sketch follows this example).
    ```scala
    scala> spark.read.schema("a DECIMAL(3,2)").parquet("/tmp/decimal").show
    java.lang.UnsupportedOperationException: Schema evolution not supported.
    ```
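
    If forcing a different decimal type is not required, one way to avoid the exception (a sketch, not part of this patch) is to let Spark use the decimal type recorded in the Parquet footer; with the data written above (`SELECT 1.0 a`), this should report `decimal(2,1)`.
    ```scala
    scala> // Sketch: read without an explicit schema so the footer's decimal(2,1) is used
    scala> spark.read.parquet("/tmp/decimal").printSchema
    root
     |-- a: decimal(2,1) (nullable = true)
    ```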
    
    2. **MR** path (used for complex schemas, or when selected via explicit configuration): Spark returns the correct results. A configuration sketch follows the example below.
    ```scala
    scala> spark.read.schema("a DECIMAL(3,2), b DECIMAL(18, 3), c MAP<INT,INT>").parquet("/tmp/decimal").show
    +----+-------+--------+
    |   a|      b|       c|
    +----+-------+--------+
    |1.00|100.000|{1 -> 2}|
    +----+-------+--------+
    
    scala> spark.read.schema("a DECIMAL(3,2), b DECIMAL(18, 3), c MAP<INT,INT>").parquet("/tmp/decimal").printSchema
    root
     |-- a: decimal(3,2) (nullable = true)
     |-- b: decimal(18,3) (nullable = true)
     |-- c: map (nullable = true)
     |    |-- key: integer
     |    |-- value: integer (valueContainsNull = true)
    ```
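
    The "explicit configuration" route refers to disabling the vectorized reader so that even a simple schema goes through the fixed MR path. A minimal sketch using the standard `spark.sql.parquet.enableVectorizedReader` setting:
    ```scala
    scala> // Sketch: route the scan through the MR path by disabling the vectorized reader
    scala> spark.conf.set("spark.sql.parquet.enableVectorizedReader", "false")
    scala> spark.read.schema("a DECIMAL(3,2)").parquet("/tmp/decimal").show
    +----+
    |   a|
    +----+
    |1.00|
    +----+
    ```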
    
    Yes, this is a user-facing change: it fixes the correctness issue.
    
    The patch passes with the newly added test case.
    
    Closes #31319 from dongjoon-hyun/SPARK-34212.
    
    Lead-authored-by: Dongjoon Hyun <dh...@apple.com>
    Co-authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
    (cherry picked from commit dbf051c50a17d644ecc1823e96eede4a5a6437fd)
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../parquet/VectorizedColumnReader.java            | 45 ++++++++++++++-----
 .../datasources/parquet/ParquetRowConverter.scala  | 22 +++++++--
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 52 ++++++++++++++++++++++
 3 files changed, 105 insertions(+), 14 deletions(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index f264281..7681ba9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -31,6 +31,7 @@ import org.apache.parquet.column.Encoding;
 import org.apache.parquet.column.page.*;
 import org.apache.parquet.column.values.ValuesReader;
 import org.apache.parquet.io.api.Binary;
+import org.apache.parquet.schema.DecimalMetadata;
 import org.apache.parquet.schema.OriginalType;
 import org.apache.parquet.schema.PrimitiveType;
 
@@ -39,6 +40,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime;
 import org.apache.spark.sql.execution.datasources.DataSourceUtils;
 import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException;
 import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
+import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.DecimalType;
 
@@ -105,6 +107,27 @@ public class VectorizedColumnReader {
   private static final ZoneId UTC = ZoneOffset.UTC;
   private final String datetimeRebaseMode;
 
+  private boolean isDecimalTypeMatched(DataType dt) {
+    DecimalType d = (DecimalType) dt;
+    DecimalMetadata dm = descriptor.getPrimitiveType().getDecimalMetadata();
+    return dm != null && dm.getPrecision() == d.precision() && dm.getScale() == d.scale();
+  }
+
+  private boolean canReadAsIntDecimal(DataType dt) {
+    if (!DecimalType.is32BitDecimalType(dt)) return false;
+    return isDecimalTypeMatched(dt);
+  }
+
+  private boolean canReadAsLongDecimal(DataType dt) {
+    if (!DecimalType.is64BitDecimalType(dt)) return false;
+    return isDecimalTypeMatched(dt);
+  }
+
+  private boolean canReadAsBinaryDecimal(DataType dt) {
+    if (!DecimalType.isByteArrayDecimalType(dt)) return false;
+    return isDecimalTypeMatched(dt);
+  }
+
   public VectorizedColumnReader(
       ColumnDescriptor descriptor,
       OriginalType originalType,
@@ -309,7 +332,7 @@ public class VectorizedColumnReader {
     switch (descriptor.getPrimitiveType().getPrimitiveTypeName()) {
       case INT32:
         if (column.dataType() == DataTypes.IntegerType ||
-            DecimalType.is32BitDecimalType(column.dataType()) ||
+            canReadAsIntDecimal(column.dataType()) ||
             (column.dataType() == DataTypes.DateType && "CORRECTED".equals(datetimeRebaseMode))) {
           for (int i = rowId; i < rowId + num; ++i) {
             if (!column.isNullAt(i)) {
@@ -343,7 +366,7 @@ public class VectorizedColumnReader {
 
       case INT64:
         if (column.dataType() == DataTypes.LongType ||
-            DecimalType.is64BitDecimalType(column.dataType()) ||
+            canReadAsLongDecimal(column.dataType()) ||
             (originalType == OriginalType.TIMESTAMP_MICROS &&
               "CORRECTED".equals(datetimeRebaseMode))) {
           for (int i = rowId; i < rowId + num; ++i) {
@@ -434,21 +457,21 @@ public class VectorizedColumnReader {
         break;
       case FIXED_LEN_BYTE_ARRAY:
         // DecimalType written in the legacy mode
-        if (DecimalType.is32BitDecimalType(column.dataType())) {
+        if (canReadAsIntDecimal(column.dataType())) {
           for (int i = rowId; i < rowId + num; ++i) {
             if (!column.isNullAt(i)) {
               Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
               column.putInt(i, (int) ParquetRowConverter.binaryToUnscaledLong(v));
             }
           }
-        } else if (DecimalType.is64BitDecimalType(column.dataType())) {
+        } else if (canReadAsLongDecimal(column.dataType())) {
           for (int i = rowId; i < rowId + num; ++i) {
             if (!column.isNullAt(i)) {
               Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
               column.putLong(i, ParquetRowConverter.binaryToUnscaledLong(v));
             }
           }
-        } else if (DecimalType.isByteArrayDecimalType(column.dataType())) {
+        } else if (canReadAsBinaryDecimal(column.dataType())) {
           for (int i = rowId; i < rowId + num; ++i) {
             if (!column.isNullAt(i)) {
               Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
@@ -484,7 +507,7 @@ public class VectorizedColumnReader {
     // This is where we implement support for the valid type conversions.
     // TODO: implement remaining type conversions
     if (column.dataType() == DataTypes.IntegerType ||
-        DecimalType.is32BitDecimalType(column.dataType())) {
+        canReadAsIntDecimal(column.dataType())) {
       defColumn.readIntegers(
           num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
     } else if (column.dataType() == DataTypes.ByteType) {
@@ -510,7 +533,7 @@ public class VectorizedColumnReader {
   private void readLongBatch(int rowId, int num, WritableColumnVector column) throws IOException {
     // This is where we implement support for the valid type conversions.
     if (column.dataType() == DataTypes.LongType ||
-        DecimalType.is64BitDecimalType(column.dataType())) {
+        canReadAsLongDecimal(column.dataType())) {
       defColumn.readLongs(
         num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
     } else if (originalType == OriginalType.TIMESTAMP_MICROS) {
@@ -574,7 +597,7 @@ public class VectorizedColumnReader {
     // TODO: implement remaining type conversions
     VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
     if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType
-            || DecimalType.isByteArrayDecimalType(column.dataType())) {
+            || canReadAsBinaryDecimal(column.dataType())) {
       defColumn.readBinarys(num, column, rowId, maxDefLevel, data);
     } else if (column.dataType() == DataTypes.TimestampType) {
       if (!shouldConvertTimestamps()) {
@@ -612,7 +635,7 @@ public class VectorizedColumnReader {
     VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
     // This is where we implement support for the valid type conversions.
     // TODO: implement remaining type conversions
-    if (DecimalType.is32BitDecimalType(column.dataType())) {
+    if (canReadAsIntDecimal(column.dataType())) {
       for (int i = 0; i < num; i++) {
         if (defColumn.readInteger() == maxDefLevel) {
           column.putInt(rowId + i,
@@ -621,7 +644,7 @@ public class VectorizedColumnReader {
           column.putNull(rowId + i);
         }
       }
-    } else if (DecimalType.is64BitDecimalType(column.dataType())) {
+    } else if (canReadAsLongDecimal(column.dataType())) {
       for (int i = 0; i < num; i++) {
         if (defColumn.readInteger() == maxDefLevel) {
           column.putLong(rowId + i,
@@ -630,7 +653,7 @@ public class VectorizedColumnReader {
           column.putNull(rowId + i);
         }
       }
-    } else if (DecimalType.isByteArrayDecimalType(column.dataType())) {
+    } else if (canReadAsBinaryDecimal(column.dataType())) {
       for (int i = 0; i < num; i++) {
         if (defColumn.readInteger() == maxDefLevel) {
           column.putByteArray(rowId + i, data.readBinary(arrayLen).getBytes());
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index 9d37f17..5005d41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -234,6 +234,19 @@ private[parquet] class ParquetRowConverter(
   }
 
   /**
+   * Get a precision and a scale to interpret parquet decimal values.
+   * 1. If there is a decimal metadata, we read decimal values with the given precision and scale.
+   * 2. If there is no metadata, we read decimal values with scale `0` because it's plain integers
+   *    when it is written into INT32/INT64/BINARY/FIXED_LEN_BYTE_ARRAY types.
+   */
+  private def getPrecisionAndScale(parquetType: Type, t: DecimalType): (Int, Int) = {
+    val metadata = parquetType.asPrimitiveType().getDecimalMetadata
+    val precision = if (metadata == null) t.precision else metadata.getPrecision()
+    val scale = if (metadata == null) 0 else metadata.getScale()
+    (precision, scale)
+  }
+
+  /**
    * Creates a converter for the given Parquet type `parquetType` and Spark SQL data type
    * `catalystType`. Converted values are handled by `updater`.
    */
@@ -260,17 +273,20 @@ private[parquet] class ParquetRowConverter(
 
       // For INT32 backed decimals
       case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 =>
-        new ParquetIntDictionaryAwareDecimalConverter(t.precision, t.scale, updater)
+        val (precision, scale) = getPrecisionAndScale(parquetType, t)
+        new ParquetIntDictionaryAwareDecimalConverter(precision, scale, updater)
 
       // For INT64 backed decimals
       case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 =>
-        new ParquetLongDictionaryAwareDecimalConverter(t.precision, t.scale, updater)
+        val (precision, scale) = getPrecisionAndScale(parquetType, t)
+        new ParquetLongDictionaryAwareDecimalConverter(precision, scale, updater)
 
       // For BINARY and FIXED_LEN_BYTE_ARRAY backed decimals
       case t: DecimalType
         if parquetType.asPrimitiveType().getPrimitiveTypeName == FIXED_LEN_BYTE_ARRAY ||
            parquetType.asPrimitiveType().getPrimitiveTypeName == BINARY =>
-        new ParquetBinaryDictionaryAwareDecimalConverter(t.precision, t.scale, updater)
+        val (precision, scale) = getPrecisionAndScale(parquetType, t)
+        new ParquetBinaryDictionaryAwareDecimalConverter(precision, scale, updater)
 
       case t: DecimalType =>
         throw new RuntimeException(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index fdad69d..c8e02c7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.command.FunctionsCommand
+import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
@@ -3575,6 +3576,57 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       }
     })
   }
+
+  test("SPARK-34212 Parquet should read decimals correctly") {
+    // a is int-decimal (4 bytes), b is long-decimal (8 bytes), c is binary-decimal (16 bytes)
+    val df = sql("SELECT 1.0 a, CAST(1.23 AS DECIMAL(17, 2)) b, CAST(1.23 AS DECIMAL(36, 2)) c")
+
+    withTempPath { path =>
+      df.write.parquet(path.toString)
+
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+        val schema1 = "a DECIMAL(3, 2), b DECIMAL(18, 3), c DECIMAL(37, 3)"
+        checkAnswer(spark.read.schema(schema1).parquet(path.toString), df)
+        val schema2 = "a DECIMAL(3, 0), b DECIMAL(18, 1), c DECIMAL(37, 1)"
+        checkAnswer(spark.read.schema(schema2).parquet(path.toString), Row(1, 1.2, 1.2))
+      }
+
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
+        val e1 = intercept[SparkException] {
+          spark.read.schema("a DECIMAL(3, 2)").parquet(path.toString).collect()
+        }.getCause.getCause
+        assert(e1.isInstanceOf[SchemaColumnConvertNotSupportedException])
+
+        val e2 = intercept[SparkException] {
+          spark.read.schema("b DECIMAL(18, 1)").parquet(path.toString).collect()
+        }.getCause.getCause
+        assert(e2.isInstanceOf[SchemaColumnConvertNotSupportedException])
+
+        val e3 = intercept[SparkException] {
+          spark.read.schema("c DECIMAL(37, 1)").parquet(path.toString).collect()
+        }.getCause.getCause
+        assert(e3.isInstanceOf[SchemaColumnConvertNotSupportedException])
+      }
+    }
+
+    withTempPath { path =>
+      val df2 = sql(s"SELECT 1 a, ${Int.MaxValue + 1L} b, CAST(2 AS BINARY) c")
+      df2.write.parquet(path.toString)
+
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+        val schema = "a DECIMAL(3, 2), b DECIMAL(17, 2), c DECIMAL(36, 2)"
+        checkAnswer(spark.read.schema(schema).parquet(path.toString),
+          Row(BigDecimal(100, 2), BigDecimal((Int.MaxValue + 1L) * 100, 2), BigDecimal(200, 2)))
+      }
+
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
+        val e = intercept[SparkException] {
+          spark.read.schema("a DECIMAL(3, 2)").parquet(path.toString).collect()
+        }.getCause.getCause
+        assert(e.isInstanceOf[SchemaColumnConvertNotSupportedException])
+      }
+    }
+  }
 }
 
 case class Foo(bar: Option[String])

