You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2021/01/26 23:34:25 UTC
[spark] branch branch-2.4 updated: [SPARK-34212][SQL] Fix incorrect
decimal reading from Parquet files
This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push:
new 6519a7e [SPARK-34212][SQL] Fix incorrect decimal reading from Parquet files
6519a7e is described below
commit 6519a7e060f4be134dc51a3792ad26dbbda0a050
Author: Dongjoon Hyun <dh...@apple.com>
AuthorDate: Tue Jan 26 15:13:39 2021 -0800
[SPARK-34212][SQL] Fix incorrect decimal reading from Parquet files
This PR aims to fix the correctness issues when reading decimal values from Parquet files.
- For the **MR** code path, `ParquetRowConverter` reads Parquet decimal values using the original precision and scale recorded in the file's footer.
- For the **Vectorized** code path, `VectorizedColumnReader` throws `SchemaColumnConvertNotSupportedException` when the file's decimal precision or scale differs from the requested schema.
Currently, Spark returns incorrect results when the Parquet file's decimal precision and scale differ from Spark's schema. This happens when there are multiple files with different decimal schemas, or when the Hive Metastore has a new schema.
**BEFORE (Simplified example for correctness)**
```scala
scala> sql("SELECT 1.0 a").write.parquet("/tmp/decimal")
scala> spark.read.schema("a DECIMAL(3,2)").parquet("/tmp/decimal").show
+----+
| a|
+----+
|0.10|
+----+
```
This works correctly with the other data sources (ORC, JSON, and CSV), as shown below.
```scala
scala> sql("SELECT 1.0 a").write.orc("/tmp/decimal_orc")
scala> spark.read.schema("a DECIMAL(3,2)").orc("/tmp/decimal_orc").show
+----+
| a|
+----+
|1.00|
+----+
```
**AFTER**
1. **Vectorized** path: instead of returning an incorrect result, Spark raises an explicit exception.
```scala
scala> spark.read.schema("a DECIMAL(3,2)").parquet("/tmp/decimal").show
java.lang.UnsupportedOperationException: Schema evolution not supported.
```
2. **MR** path (complex schema or explicit configuration): Spark returns correct results.
```scala
scala> spark.read.schema("a DECIMAL(3,2), b DECIMAL(18, 3), c MAP<INT,INT>").parquet("/tmp/decimal").show
+----+-------+--------+
| a| b| c|
+----+-------+--------+
|1.00|100.000|{1 -> 2}|
+----+-------+--------+
scala> spark.read.schema("a DECIMAL(3,2), b DECIMAL(18, 3), c MAP<INT,INT>").parquet("/tmp/decimal").printSchema
root
|-- a: decimal(3,2) (nullable = true)
|-- b: decimal(18,3) (nullable = true)
|-- c: map (nullable = true)
| |-- key: integer
| |-- value: integer (valueContainsNull = true)
```
Does this PR introduce any user-facing change? Yes. This fixes the correctness issue.
How was this patch tested? It passes the newly added test case.
Closes #31319 from dongjoon-hyun/SPARK-34212.
Lead-authored-by: Dongjoon Hyun <dh...@apple.com>
Co-authored-by: Wenchen Fan <we...@databricks.com>
Signed-off-by: Dongjoon Hyun <dh...@apple.com>
(cherry picked from commit dbf051c50a17d644ecc1823e96eede4a5a6437fd)
Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
.../parquet/VectorizedColumnReader.java | 45 ++++++++++++++-----
.../datasources/parquet/ParquetRowConverter.scala | 22 +++++++--
.../scala/org/apache/spark/sql/SQLQuerySuite.scala | 52 ++++++++++++++++++++++
3 files changed, 105 insertions(+), 14 deletions(-)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index ba26b57..4739089 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -30,12 +30,14 @@ import org.apache.parquet.column.Encoding;
import org.apache.parquet.column.page.*;
import org.apache.parquet.column.values.ValuesReader;
import org.apache.parquet.io.api.Binary;
+import org.apache.parquet.schema.DecimalMetadata;
import org.apache.parquet.schema.OriginalType;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
+import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DecimalType;
@@ -101,6 +103,27 @@ public class VectorizedColumnReader {
private final TimeZone convertTz;
private static final TimeZone UTC = DateTimeUtils.TimeZoneUTC();
+ private boolean isDecimalTypeMatched(DataType dt) {
+ DecimalType d = (DecimalType) dt;
+ DecimalMetadata dm = descriptor.getPrimitiveType().getDecimalMetadata();
+ return dm != null && dm.getPrecision() == d.precision() && dm.getScale() == d.scale();
+ }
+
+ private boolean canReadAsIntDecimal(DataType dt) {
+ if (!DecimalType.is32BitDecimalType(dt)) return false;
+ return isDecimalTypeMatched(dt);
+ }
+
+ private boolean canReadAsLongDecimal(DataType dt) {
+ if (!DecimalType.is64BitDecimalType(dt)) return false;
+ return isDecimalTypeMatched(dt);
+ }
+
+ private boolean canReadAsBinaryDecimal(DataType dt) {
+ if (!DecimalType.isByteArrayDecimalType(dt)) return false;
+ return isDecimalTypeMatched(dt);
+ }
+
public VectorizedColumnReader(
ColumnDescriptor descriptor,
OriginalType originalType,
@@ -261,7 +284,7 @@ public class VectorizedColumnReader {
switch (descriptor.getPrimitiveType().getPrimitiveTypeName()) {
case INT32:
if (column.dataType() == DataTypes.IntegerType ||
- DecimalType.is32BitDecimalType(column.dataType())) {
+ canReadAsIntDecimal(column.dataType())) {
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
column.putInt(i, dictionary.decodeToInt(dictionaryIds.getDictId(i)));
@@ -286,7 +309,7 @@ public class VectorizedColumnReader {
case INT64:
if (column.dataType() == DataTypes.LongType ||
- DecimalType.is64BitDecimalType(column.dataType()) ||
+ canReadAsLongDecimal(column.dataType()) ||
originalType == OriginalType.TIMESTAMP_MICROS) {
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
@@ -357,21 +380,21 @@ public class VectorizedColumnReader {
break;
case FIXED_LEN_BYTE_ARRAY:
// DecimalType written in the legacy mode
- if (DecimalType.is32BitDecimalType(column.dataType())) {
+ if (canReadAsIntDecimal(column.dataType())) {
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
column.putInt(i, (int) ParquetRowConverter.binaryToUnscaledLong(v));
}
}
- } else if (DecimalType.is64BitDecimalType(column.dataType())) {
+ } else if (canReadAsLongDecimal(column.dataType())) {
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
column.putLong(i, ParquetRowConverter.binaryToUnscaledLong(v));
}
}
- } else if (DecimalType.isByteArrayDecimalType(column.dataType())) {
+ } else if (canReadAsBinaryDecimal(column.dataType())) {
for (int i = rowId; i < rowId + num; ++i) {
if (!column.isNullAt(i)) {
Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
@@ -407,7 +430,7 @@ public class VectorizedColumnReader {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType ||
- DecimalType.is32BitDecimalType(column.dataType())) {
+ canReadAsIntDecimal(column.dataType())) {
defColumn.readIntegers(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else if (column.dataType() == DataTypes.ByteType) {
@@ -424,7 +447,7 @@ public class VectorizedColumnReader {
private void readLongBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
if (column.dataType() == DataTypes.LongType ||
- DecimalType.is64BitDecimalType(column.dataType()) ||
+ canReadAsLongDecimal(column.dataType()) ||
originalType == OriginalType.TIMESTAMP_MICROS) {
defColumn.readLongs(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
@@ -468,7 +491,7 @@ public class VectorizedColumnReader {
// TODO: implement remaining type conversions
VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType
- || DecimalType.isByteArrayDecimalType(column.dataType())) {
+ || canReadAsBinaryDecimal(column.dataType())) {
defColumn.readBinarys(num, column, rowId, maxDefLevel, data);
} else if (column.dataType() == DataTypes.TimestampType) {
if (!shouldConvertTimestamps()) {
@@ -506,7 +529,7 @@ public class VectorizedColumnReader {
VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
- if (DecimalType.is32BitDecimalType(column.dataType())) {
+ if (canReadAsIntDecimal(column.dataType())) {
for (int i = 0; i < num; i++) {
if (defColumn.readInteger() == maxDefLevel) {
column.putInt(rowId + i,
@@ -515,7 +538,7 @@ public class VectorizedColumnReader {
column.putNull(rowId + i);
}
}
- } else if (DecimalType.is64BitDecimalType(column.dataType())) {
+ } else if (canReadAsLongDecimal(column.dataType())) {
for (int i = 0; i < num; i++) {
if (defColumn.readInteger() == maxDefLevel) {
column.putLong(rowId + i,
@@ -524,7 +547,7 @@ public class VectorizedColumnReader {
column.putNull(rowId + i);
}
}
- } else if (DecimalType.isByteArrayDecimalType(column.dataType())) {
+ } else if (canReadAsBinaryDecimal(column.dataType())) {
for (int i = 0; i < num; i++) {
if (defColumn.readInteger() == maxDefLevel) {
column.putByteArray(rowId + i, data.readBinary(arrayLen).getBytes());
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index 1199725..0d22fe5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -210,6 +210,19 @@ private[parquet] class ParquetRowConverter(
}
/**
+ * Get a precision and a scale to interpret parquet decimal values.
+ * 1. If there is a decimal metadata, we read decimal values with the given precision and scale.
+ * 2. If there is no metadata, we read decimal values with scale `0` because it's plain integers
+ * when it is written into INT32/INT64/BINARY/FIXED_LEN_BYTE_ARRAY types.
+ */
+ private def getPrecisionAndScale(parquetType: Type, t: DecimalType): (Int, Int) = {
+ val metadata = parquetType.asPrimitiveType().getDecimalMetadata
+ val precision = if (metadata == null) t.precision else metadata.getPrecision()
+ val scale = if (metadata == null) 0 else metadata.getScale()
+ (precision, scale)
+ }
+
+ /**
* Creates a converter for the given Parquet type `parquetType` and Spark SQL data type
* `catalystType`. Converted values are handled by `updater`.
*/
@@ -236,17 +249,20 @@ private[parquet] class ParquetRowConverter(
// For INT32 backed decimals
case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 =>
- new ParquetIntDictionaryAwareDecimalConverter(t.precision, t.scale, updater)
+ val (precision, scale) = getPrecisionAndScale(parquetType, t)
+ new ParquetIntDictionaryAwareDecimalConverter(precision, scale, updater)
// For INT64 backed decimals
case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 =>
- new ParquetLongDictionaryAwareDecimalConverter(t.precision, t.scale, updater)
+ val (precision, scale) = getPrecisionAndScale(parquetType, t)
+ new ParquetLongDictionaryAwareDecimalConverter(precision, scale, updater)
// For BINARY and FIXED_LEN_BYTE_ARRAY backed decimals
case t: DecimalType
if parquetType.asPrimitiveType().getPrimitiveTypeName == FIXED_LEN_BYTE_ARRAY ||
parquetType.asPrimitiveType().getPrimitiveTypeName == BINARY =>
- new ParquetBinaryDictionaryAwareDecimalConverter(t.precision, t.scale, updater)
+ val (precision, scale) = getPrecisionAndScale(parquetType, t)
+ new ParquetBinaryDictionaryAwareDecimalConverter(precision, scale, updater)
case t: DecimalType =>
throw new RuntimeException(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index ab2a1c9..1af50bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.datasources.FilePartition
+import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -3140,6 +3141,57 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
})
}
+
+ test("SPARK-34212 Parquet should read decimals correctly") {
+ // a is int-decimal (4 bytes), b is long-decimal (8 bytes), c is binary-decimal (16 bytes)
+ val df = sql("SELECT 1.0 a, CAST(1.23 AS DECIMAL(17, 2)) b, CAST(1.23 AS DECIMAL(36, 2)) c")
+
+ withTempPath { path =>
+ df.write.parquet(path.toString)
+
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ val schema1 = "a DECIMAL(3, 2), b DECIMAL(18, 3), c DECIMAL(37, 3)"
+ checkAnswer(spark.read.schema(schema1).parquet(path.toString), df)
+ val schema2 = "a DECIMAL(3, 0), b DECIMAL(18, 1), c DECIMAL(37, 1)"
+ checkAnswer(spark.read.schema(schema2).parquet(path.toString), Row(1, 1.2, 1.2))
+ }
+
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
+ val e1 = intercept[SparkException] {
+ spark.read.schema("a DECIMAL(3, 2)").parquet(path.toString).collect()
+ }.getCause.getCause
+ assert(e1.isInstanceOf[SchemaColumnConvertNotSupportedException])
+
+ val e2 = intercept[SparkException] {
+ spark.read.schema("b DECIMAL(18, 1)").parquet(path.toString).collect()
+ }.getCause.getCause
+ assert(e2.isInstanceOf[SchemaColumnConvertNotSupportedException])
+
+ val e3 = intercept[SparkException] {
+ spark.read.schema("c DECIMAL(37, 1)").parquet(path.toString).collect()
+ }.getCause.getCause
+ assert(e3.isInstanceOf[SchemaColumnConvertNotSupportedException])
+ }
+ }
+
+ withTempPath { path =>
+ val df2 = sql(s"SELECT 1 a, ${Int.MaxValue + 1L} b")
+ df2.write.parquet(path.toString)
+
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ val schema = "a DECIMAL(3, 2), b DECIMAL(17, 2)"
+ checkAnswer(spark.read.schema(schema).parquet(path.toString),
+ Row(BigDecimal(100, 2), BigDecimal((Int.MaxValue + 1L) * 100, 2)))
+ }
+
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
+ val e = intercept[SparkException] {
+ spark.read.schema("a DECIMAL(3, 2)").parquet(path.toString).collect()
+ }.getCause.getCause
+ assert(e.isInstanceOf[SchemaColumnConvertNotSupportedException])
+ }
+ }
+ }
}
case class Foo(bar: Option[String])
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org