Posted to commits@spark.apache.org by we...@apache.org on 2021/02/19 08:36:54 UTC
[spark] branch master updated: [SPARK-34314][SQL] Fix partitions schema inference
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b26e7b5 [SPARK-34314][SQL] Fix partitions schema inference
b26e7b5 is described below
commit b26e7b510bbaee63c4095ab47e75ff2a70e377d7
Author: Max Gekk <ma...@gmail.com>
AuthorDate: Fri Feb 19 08:36:13 2021 +0000
[SPARK-34314][SQL] Fix partitions schema inference
### What changes were proposed in this pull request?
Infer the partitions schema by:
1. inferring the common type over all partition part values, and
2. casting those values to the common type
Before the changes:
1. Spark creates a literal with the most appropriate type for each concrete partition value, e.g. `part0=-0` -> `Literal(0, IntegerType)`, `part0=abc` -> `Literal(UTF8String.fromString("abc"), StringType)`.
2. Finds the common type for all literals of a partition column. For the example above, it is `StringType`.
3. Casts those literals to the desired type:
- `Cast(Literal(0, IntegerType), StringType)` -> `UTF8String.fromString("0")`
- `Cast(Literal(UTF8String.fromString("abc"), StringType), StringType)` -> `UTF8String.fromString("abc")`
In the example, we get the partition part value "0", which differs from the original value "-0". Spark shouldn't modify partition part values of the string type because that can influence query results.
Closes #31423
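The lossy round trip is easy to reproduce outside Spark. A minimal standalone Scala sketch (not the actual Spark code path) of what the old scheme effectively did to the raw value "-0":
```
// Old scheme, in miniature: parse the raw partition value into a typed
// value first, then render it back as a string once the common type of
// the column turns out to be StringType.
val raw = "-0"
val asInt = Integer.parseInt(raw) // "-0" parses to the Int 0
val rendered = asInt.toString     // "0" -- the leading minus is gone
assert(rendered != raw)           // the original partition value is lost
```
After this change, inference only decides the type from the raw string; the value is cast from the raw string once, after the common type is known, so a string-typed partition value stays untouched.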
### Why are the changes needed?
The changes fix the bug demonstrated by the following example:
1. There are partitioned Parquet files (the file format doesn't matter):
```
/private/var/folders/p3/dfs6mf655d7fnjrsjvldh0tc0000gn/T/spark-e09eae99-7ecf-4ab2-b99b-f63f8dea658d
├── _SUCCESS
├── part=-0
│   └── part-00001-02144398-2896-4d21-9628-a8743d098cb4.c000.snappy.parquet
└── part=AA
    └── part-00000-02144398-2896-4d21-9628-a8743d098cb4.c000.snappy.parquet
```
The files are placed into two partitions, "AA" and **"-0"**.
2. When reading them back without a user-specified schema:
```
val df = spark.read.parquet(path)
df.printSchema()
root
 |-- id: integer (nullable = true)
 |-- part: string (nullable = true)
```
the inferred type of the partition column `part` is the **string** type.
3. The expected values in the column `part` are "AA" and "-0", but we get:
```
df.show(false)
+---+----+
|id |part|
+---+----+
|0  |AA  |
|1  |0   |
+---+----+
```
So, Spark returns **"0"** instead of **"-0"**.
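For contrast, a simplified sketch of the new two-phase scheme (simplified, hypothetical types, not the exact Spark internals; the real `TypedPartValue` in the diff below carries a Catalyst `DataType`): keep the raw string together with its inferred type, widen the types across all partition directories, and only then cast each raw string to the common type:
```
case class TypedPartValue(value: String, dataType: String)

// Infer a type for a single raw partition value.
def inferType(raw: String): String =
  if (scala.util.Try(raw.toInt).isSuccess) "int" else "string"

// Cast the raw string once the final type is known.
def cast(tv: TypedPartValue): Any = tv.dataType match {
  case "int"    => tv.value.toInt
  case "string" => tv.value // the raw value is preserved verbatim
}

val parts = Seq("-0", "AA").map(v => TypedPartValue(v, inferType(v)))
// Widening: any string member forces the whole column to string.
val common = if (parts.exists(_.dataType == "string")) "string" else "int"
val resolved = parts.map(_.copy(dataType = common))
resolved.map(cast) // Seq("-0", "AA") -- "-0" survives intact
```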
### Does this PR introduce _any_ user-facing change?
Yes, this PR can change query results.
### How was this patch tested?
By running the new test and the existing test suites:
```
$ build/sbt "test:testOnly *FileIndexSuite"
$ build/sbt "test:testOnly *ParquetV1PartitionDiscoverySuite"
$ build/sbt "test:testOnly *ParquetV2PartitionDiscoverySuite"
```
Closes #31549 from MaxGekk/fix-partition-file-index-2.
Authored-by: Max Gekk <ma...@gmail.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../execution/datasources/PartitioningUtils.scala | 118 ++++++++++++---------
.../parquet/ParquetPartitionDiscoverySuite.scala | 73 ++++++++-----
2 files changed, 109 insertions(+), 82 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index 69123ee..0a5ddae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -25,6 +25,7 @@ import java.util.Locale
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Try
+import scala.util.control.NonFatal
import org.apache.hadoop.fs.Path
@@ -36,6 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
+import org.apache.spark.unsafe.types.UTF8String
// TODO: We should tighten up visibility of the classes here once we clean up Hive coupling.
@@ -62,9 +64,11 @@ object PartitioningUtils {
val timestampPartitionPattern = "yyyy-MM-dd HH:mm:ss[.S]"
- private[datasources] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal])
+ case class TypedPartValue(value: String, dataType: DataType)
+
+ case class PartitionValues(columnNames: Seq[String], typedValues: Seq[TypedPartValue])
{
- require(columnNames.size == literals.size)
+ require(columnNames.size == typedValues.size)
}
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.{escapePathName, unescapePathName, DEFAULT_PARTITION_NAME}
@@ -172,13 +176,12 @@ object PartitioningUtils {
"root directory of the table. If there are multiple root directories, " +
"please load them separately and then union them.")
- val resolvedPartitionValues =
- resolvePartitions(pathsWithPartitionValues, caseSensitive, zoneId)
+ val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues, caseSensitive)
// Creates the StructType which represents the partition columns.
val fields = {
- val PartitionValues(columnNames, literals) = resolvedPartitionValues.head
- columnNames.zip(literals).map { case (name, Literal(_, dataType)) =>
+ val PartitionValues(columnNames, typedValues) = resolvedPartitionValues.head
+ columnNames.zip(typedValues).map { case (name, TypedPartValue(_, dataType)) =>
// We always assume partition columns are nullable since we've no idea whether null values
// will be appended in the future.
val resultName = userSpecifiedNames.getOrElse(name, name)
@@ -189,8 +192,19 @@ object PartitioningUtils {
// Finally, we create `Partition`s based on paths and resolved partition values.
val partitions = resolvedPartitionValues.zip(pathsWithPartitionValues).map {
- case (PartitionValues(_, literals), (path, _)) =>
- PartitionPath(InternalRow.fromSeq(literals.map(_.value)), path)
+ case (PartitionValues(columnNames, typedValues), (path, _)) =>
+ val rowValues = columnNames.zip(typedValues).map { case (columnName, typedValue) =>
+ try {
+ castPartValueToDesiredType(typedValue.dataType, typedValue.value, zoneId)
+ } catch {
+ case NonFatal(_) =>
+ if (validatePartitionColumns) {
+ throw new RuntimeException(s"Failed to cast value `${typedValue.value}` to " +
+ s"`${typedValue.dataType}` for partition column `$columnName`")
+ } else null
+ }
+ }
+ PartitionPath(InternalRow.fromSeq(rowValues), path)
}
PartitionSpec(StructType(fields), partitions)
@@ -226,7 +240,7 @@ object PartitioningUtils {
zoneId: ZoneId,
dateFormatter: DateFormatter,
timestampFormatter: TimestampFormatter): (Option[PartitionValues], Option[Path]) = {
- val columns = ArrayBuffer.empty[(String, Literal)]
+ val columns = ArrayBuffer.empty[(String, TypedPartValue)]
// Old Hadoop versions don't have `Path.isRoot`
var finished = path.getParent == null
// currentPath is the current path that we will use to parse partition column value.
@@ -284,7 +298,7 @@ object PartitioningUtils {
validatePartitionColumns: Boolean,
zoneId: ZoneId,
dateFormatter: DateFormatter,
- timestampFormatter: TimestampFormatter): Option[(String, Literal)] = {
+ timestampFormatter: TimestampFormatter): Option[(String, TypedPartValue)] = {
val equalSignIndex = columnSpec.indexOf('=')
if (equalSignIndex == -1) {
None
@@ -295,23 +309,10 @@ object PartitioningUtils {
val rawColumnValue = columnSpec.drop(equalSignIndex + 1)
assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'")
- val literal = if (userSpecifiedDataTypes.contains(columnName)) {
+ val dataType = if (userSpecifiedDataTypes.contains(columnName)) {
// SPARK-26188: if user provides corresponding column schema, get the column value without
// inference, and then cast it as user specified data type.
- val dataType = userSpecifiedDataTypes(columnName)
- val columnValueLiteral = inferPartitionColumnValue(
- rawColumnValue,
- false,
- zoneId,
- dateFormatter,
- timestampFormatter)
- val columnValue = columnValueLiteral.eval()
- val castedValue = Cast(columnValueLiteral, dataType, Option(zoneId.getId)).eval()
- if (validatePartitionColumns && columnValue != null && castedValue == null) {
- throw new RuntimeException(s"Failed to cast value `$columnValue` to `$dataType` " +
- s"for partition column `$columnName`")
- }
- Literal.create(castedValue, dataType)
+ userSpecifiedDataTypes(columnName)
} else {
inferPartitionColumnValue(
rawColumnValue,
@@ -320,7 +321,7 @@ object PartitioningUtils {
dateFormatter,
timestampFormatter)
}
- Some(columnName -> literal)
+ Some(columnName -> TypedPartValue(rawColumnValue, dataType))
}
}
@@ -363,8 +364,7 @@ object PartitioningUtils {
*/
def resolvePartitions(
pathsWithPartitionValues: Seq[(Path, PartitionValues)],
- caseSensitive: Boolean,
- zoneId: ZoneId): Seq[PartitionValues] = {
+ caseSensitive: Boolean): Seq[PartitionValues] = {
if (pathsWithPartitionValues.isEmpty) {
Seq.empty
} else {
@@ -381,12 +381,12 @@ object PartitioningUtils {
val values = pathsWithPartitionValues.map(_._2)
val columnCount = values.head.columnNames.size
val resolvedValues = (0 until columnCount).map { i =>
- resolveTypeConflicts(values.map(_.literals(i)), zoneId)
+ resolveTypeConflicts(values.map(_.typedValues(i)))
}
// Fills resolved literals back to each partition
values.zipWithIndex.map { case (d, index) =>
- d.copy(literals = resolvedValues.map(_(index)))
+ d.copy(typedValues = resolvedValues.map(_(index)))
}
}
}
@@ -449,7 +449,7 @@ object PartitioningUtils {
typeInference: Boolean,
zoneId: ZoneId,
dateFormatter: DateFormatter,
- timestampFormatter: TimestampFormatter): Literal = {
+ timestampFormatter: TimestampFormatter): DataType = {
val decimalTry = Try {
// `BigDecimal` conversion can fail when the `field` is not a form of number.
val bigDecimal = new JBigDecimal(raw)
@@ -458,7 +458,7 @@ object PartitioningUtils {
// `DecimalType` conversion can fail when
// 1. The precision is bigger than 38.
// 2. scale is bigger than precision.
- Literal(bigDecimal)
+ DecimalType.fromDecimal(Decimal(bigDecimal))
}
val dateTry = Try {
@@ -474,7 +474,7 @@ object PartitioningUtils {
val dateValue = Cast(Literal(raw), DateType, Some(zoneId.getId)).eval()
// Disallow DateType if the cast returned null
require(dateValue != null)
- Literal.create(dateValue, DateType)
+ DateType
}
val timestampTry = Try {
@@ -486,36 +486,50 @@ object PartitioningUtils {
val timestampValue = Cast(Literal(unescapedRaw), TimestampType, Some(zoneId.getId)).eval()
// Disallow TimestampType if the cast returned null
require(timestampValue != null)
- Literal.create(timestampValue, TimestampType)
+ TimestampType
}
if (typeInference) {
// First tries integral types
- Try(Literal.create(Integer.parseInt(raw), IntegerType))
- .orElse(Try(Literal.create(JLong.parseLong(raw), LongType)))
+ Try({ Integer.parseInt(raw); IntegerType })
+ .orElse(Try { JLong.parseLong(raw); LongType })
.orElse(decimalTry)
// Then falls back to fractional types
- .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType)))
+ .orElse(Try { JDouble.parseDouble(raw); DoubleType })
// Then falls back to date/timestamp types
.orElse(timestampTry)
.orElse(dateTry)
// Then falls back to string
.getOrElse {
- if (raw == DEFAULT_PARTITION_NAME) {
- Literal.create(null, NullType)
- } else {
- Literal.create(unescapePathName(raw), StringType)
- }
+ if (raw == DEFAULT_PARTITION_NAME) NullType else StringType
}
} else {
- if (raw == DEFAULT_PARTITION_NAME) {
- Literal.create(null, NullType)
- } else {
- Literal.create(unescapePathName(raw), StringType)
- }
+ if (raw == DEFAULT_PARTITION_NAME) NullType else StringType
}
}
+ def castPartValueToDesiredType(
+ desiredType: DataType,
+ value: String,
+ zoneId: ZoneId): Any = desiredType match {
+ case _ if value == DEFAULT_PARTITION_NAME => null
+ case NullType => null
+ case StringType => UTF8String.fromString(unescapePathName(value))
+ case IntegerType => Integer.parseInt(value)
+ case LongType => JLong.parseLong(value)
+ case DoubleType => JDouble.parseDouble(value)
+ case _: DecimalType => Literal(new JBigDecimal(value)).value
+ case DateType =>
+ Cast(Literal(value), DateType, Some(zoneId.getId)).eval()
+ case TimestampType =>
+ Try {
+ Cast(Literal(unescapePathName(value)), TimestampType, Some(zoneId.getId)).eval()
+ }.getOrElse {
+ Cast(Cast(Literal(value), DateType, Some(zoneId.getId)), TimestampType).eval()
+ }
+ case dt => throw new IllegalArgumentException(s"Unexpected type $dt")
+ }
+
def validatePartitionColumn(
schema: StructType,
partitionColumns: Seq[String],
@@ -590,13 +604,11 @@ object PartitioningUtils {
* Given a collection of [[Literal]]s, resolves possible type conflicts by
* [[findWiderTypeForPartitionColumn]].
*/
- private def resolveTypeConflicts(literals: Seq[Literal], zoneId: ZoneId): Seq[Literal] = {
- val litTypes = literals.map(_.dataType)
- val desiredType = litTypes.reduce(findWiderTypeForPartitionColumn)
+ private def resolveTypeConflicts(typedValues: Seq[TypedPartValue]): Seq[TypedPartValue] = {
+ val dataTypes = typedValues.map(_.dataType)
+ val desiredType = dataTypes.reduce(findWiderTypeForPartitionColumn)
- literals.map { case l @ Literal(_, dataType) =>
- Literal.create(Cast(l, desiredType, Some(zoneId.getId)).eval(), desiredType)
- }
+ typedValues.map(tv => tv.copy(dataType = desiredType))
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index 400f4d8..5ea8c61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet
import java.io.File
import java.math.BigInteger
-import java.sql.{Date, Timestamp}
+import java.sql.Timestamp
import java.time.{ZoneId, ZoneOffset}
import java.util.{Calendar, Locale}
@@ -31,7 +31,6 @@ import org.apache.spark.SparkConf
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
-import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneUTC
import org.apache.spark.sql.execution.datasources._
@@ -72,29 +71,25 @@ abstract class ParquetPartitionDiscoverySuite
}
test("column type inference") {
- def check(raw: String, literal: Literal, zoneId: ZoneId = timeZoneId): Unit = {
- assert(inferPartitionColumnValue(raw, true, zoneId, df, tf) === literal)
+ def check(raw: String, dataType: DataType, zoneId: ZoneId = timeZoneId): Unit = {
+ assert(inferPartitionColumnValue(raw, true, zoneId, df, tf) === dataType)
}
- check("10", Literal.create(10, IntegerType))
- check("1000000000000000", Literal.create(1000000000000000L, LongType))
+ check("10", IntegerType)
+ check("1000000000000000", LongType)
val decimal = Decimal("1" * 20)
- check("1" * 20,
- Literal.create(decimal, DecimalType(decimal.precision, decimal.scale)))
- check("1.5", Literal.create(1.5, DoubleType))
- check("hello", Literal.create("hello", StringType))
- check("1990-02-24", Literal.create(Date.valueOf("1990-02-24"), DateType))
- check("1990-02-24 12:00:30",
- Literal.create(Timestamp.valueOf("1990-02-24 12:00:30"), TimestampType))
+ check("1" * 20, DecimalType(decimal.precision, decimal.scale))
+ check("1.5", DoubleType)
+ check("hello", StringType)
+ check("1990-02-24", DateType)
+ check("1990-02-24 12:00:30", TimestampType)
val c = Calendar.getInstance(TimeZoneUTC)
c.set(1990, 1, 24, 12, 0, 30)
c.set(Calendar.MILLISECOND, 0)
- check("1990-02-24 12:00:30",
- Literal.create(new Timestamp(c.getTimeInMillis), TimestampType),
- ZoneOffset.UTC)
+ check("1990-02-24 12:00:30", TimestampType, ZoneOffset.UTC)
- check(defaultPartitionName, Literal.create(null, NullType))
+ check(defaultPartitionName, NullType)
}
test("parse invalid partitioned directories") {
@@ -216,22 +211,22 @@ abstract class ParquetPartitionDiscoverySuite
check("file://path/a=10", Some {
PartitionValues(
Seq("a"),
- Seq(Literal.create(10, IntegerType)))
+ Seq(TypedPartValue("10", IntegerType)))
})
check("file://path/a=10/b=hello/c=1.5", Some {
PartitionValues(
Seq("a", "b", "c"),
Seq(
- Literal.create(10, IntegerType),
- Literal.create("hello", StringType),
- Literal.create(1.5, DoubleType)))
+ TypedPartValue("10", IntegerType),
+ TypedPartValue("hello", StringType),
+ TypedPartValue("1.5", DoubleType)))
})
check("file://path/a=10/b_hello/c=1.5", Some {
PartitionValues(
Seq("c"),
- Seq(Literal.create(1.5, DoubleType)))
+ Seq(TypedPartValue("1.5", DoubleType)))
})
check("file:///", None)
@@ -273,7 +268,7 @@ abstract class ParquetPartitionDiscoverySuite
assert(partitionSpec2 ==
Option(PartitionValues(
Seq("a"),
- Seq(Literal.create(10, IntegerType)))))
+ Seq(TypedPartValue("10", IntegerType)))))
}
test("parse partitions") {
@@ -911,15 +906,19 @@ abstract class ParquetPartitionDiscoverySuite
assert(
listConflictingPartitionColumns(
Seq(
- (new Path("file:/tmp/foo/a=1"), PartitionValues(Seq("a"), Seq(Literal(1)))),
- (new Path("file:/tmp/foo/b=1"), PartitionValues(Seq("b"), Seq(Literal(1)))))).trim ===
+ (new Path("file:/tmp/foo/a=1"),
+ PartitionValues(Seq("a"), Seq(TypedPartValue("1", IntegerType)))),
+ (new Path("file:/tmp/foo/b=1"),
+ PartitionValues(Seq("b"), Seq(TypedPartValue("1", IntegerType)))))).trim ===
makeExpectedMessage(Seq("a", "b"), Seq("file:/tmp/foo/a=1", "file:/tmp/foo/b=1")))
assert(
listConflictingPartitionColumns(
Seq(
- (new Path("file:/tmp/foo/a=1/_temporary"), PartitionValues(Seq("a"), Seq(Literal(1)))),
- (new Path("file:/tmp/foo/a=1"), PartitionValues(Seq("a"), Seq(Literal(1)))))).trim ===
+ (new Path("file:/tmp/foo/a=1/_temporary"),
+ PartitionValues(Seq("a"), Seq(TypedPartValue("1", IntegerType)))),
+ (new Path("file:/tmp/foo/a=1"),
+ PartitionValues(Seq("a"), Seq(TypedPartValue("1", IntegerType)))))).trim ===
makeExpectedMessage(
Seq("a"),
Seq("file:/tmp/foo/a=1/_temporary", "file:/tmp/foo/a=1")))
@@ -928,9 +927,10 @@ abstract class ParquetPartitionDiscoverySuite
listConflictingPartitionColumns(
Seq(
(new Path("file:/tmp/foo/a=1"),
- PartitionValues(Seq("a"), Seq(Literal(1)))),
+ PartitionValues(Seq("a"), Seq(TypedPartValue("1", IntegerType)))),
(new Path("file:/tmp/foo/a=1/b=foo"),
- PartitionValues(Seq("a", "b"), Seq(Literal(1), Literal("foo")))))).trim ===
+ PartitionValues(Seq("a", "b"),
+ Seq(TypedPartValue("1", IntegerType), TypedPartValue("foo", StringType)))))).trim ===
makeExpectedMessage(
Seq("a", "a, b"),
Seq("file:/tmp/foo/a=1", "file:/tmp/foo/a=1/b=foo")))
@@ -1039,6 +1039,21 @@ abstract class ParquetPartitionDiscoverySuite
checkAnswer(input, data)
}
}
+
+ test("SPARK-34314: preserve partition values of the string type") {
+ import testImplicits._
+ withTempPath { file =>
+ val path = file.getCanonicalPath
+ val df = Seq((0, "AA"), (1, "-0")).toDF("id", "part")
+ df.write
+ .partitionBy("part")
+ .format("parquet")
+ .save(path)
+ val readback = spark.read.parquet(path)
+ assert(readback.schema("part").dataType === StringType)
+ checkAnswer(readback, Row(0, "AA") :: Row(1, "-0") :: Nil)
+ }
+ }
}
class ParquetV1PartitionDiscoverySuite extends ParquetPartitionDiscoverySuite {