Posted to commits@spark.apache.org by we...@apache.org on 2021/02/19 08:36:54 UTC

[spark] branch master updated: [SPARK-34314][SQL] Fix partitions schema inference

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b26e7b5  [SPARK-34314][SQL] Fix partitions schema inference
b26e7b5 is described below

commit b26e7b510bbaee63c4095ab47e75ff2a70e377d7
Author: Max Gekk <ma...@gmail.com>
AuthorDate: Fri Feb 19 08:36:13 2021 +0000

    [SPARK-34314][SQL] Fix partitions schema inference
    
    ### What changes were proposed in this pull request?
    Infer the partition schema by:
    1. inferring the common type over all partition part values, and
    2. casting those values to the common type.
    
    Before the changes:
    1. Spark creates a literal with the most appropriate type for each concrete partition value, e.g. `part0=-0` -> `Literal(0, IntegerType)`, `part0=abc` -> `Literal(UTF8String.fromString("abc"), StringType)`.
    2. Finds the common type for all literals of a partition column. For the example above, it is `StringType`.
    3. Casts those literals to the desired type:
      - `Cast(Literal(0, IntegerType), StringType)` -> `UTF8String.fromString("0")`
      - `Cast(Literal(UTF8String.fromString("abc"), StringType), StringType)` -> `UTF8String.fromString("abc")`
    
    In the example, we get the partition part value "0", which differs from the original "-0". Spark shouldn't modify partition part values of the string type because that can influence query results.
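    
    A minimal spark-shell sketch of the difference (assuming the Spark 3.x catalyst internals this patch touches; the `Some("UTC")` time zone is arbitrary):
    ```
    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
    import org.apache.spark.sql.types.{IntegerType, StringType}

    // Before the changes: the directory value "-0" has already been parsed into
    // an integer literal, so casting it to the resolved common type yields "0".
    Cast(Literal(0, IntegerType), StringType, Some("UTC")).eval()
    // => UTF8String "0"

    // After the changes: only the type is inferred up front and the cast starts
    // from the raw path value, so a StringType partition keeps "-0" unchanged,
    // e.g. castPartValueToDesiredType(StringType, "-0", ZoneId.of("UTC")) => "-0".
    ```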
    
    Closes #31423
    
    ### Why are the changes needed?
    The changes fix the bug demonstrated by the following example:
    1. There are partitioned Parquet files (the file format doesn't matter):
    ```
    /private/var/folders/p3/dfs6mf655d7fnjrsjvldh0tc0000gn/T/spark-e09eae99-7ecf-4ab2-b99b-f63f8dea658d
    ├── _SUCCESS
    ├── part=-0
    │   └── part-00001-02144398-2896-4d21-9628-a8743d098cb4.c000.snappy.parquet
    └── part=AA
        └── part-00000-02144398-2896-4d21-9628-a8743d098cb4.c000.snappy.parquet
    ```
    placed into two partitions, "AA" and **"-0"** (a write sketch producing this layout follows the example).
    
    2. When reading them without a specified schema:
    ```
    val df = spark.read.parquet(path)
    df.printSchema()
    root
     |-- id: integer (nullable = true)
     |-- part: string (nullable = true)
    ```
    the inferred type of the partition column `part` is the **string** type.
    3. The expected values in the column `part` are "AA" and "-0", but we get:
    ```
    df.show(false)
    +---+----+
    |id |part|
    +---+----+
    |0  |AA  |
    |1  |0   |
    +---+----+
    ```
    So, Spark returns **"0"** instead of **"-0"**.
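    
    For reference, the layout in step 1 can be produced with a write like the following (a spark-shell sketch mirroring the test added in this patch; the path and values are only illustrative):
    ```
    import spark.implicits._

    val path = "/tmp/spark-34314-demo"  // illustrative location
    Seq((0, "AA"), (1, "-0")).toDF("id", "part")
      .write
      .partitionBy("part")
      .format("parquet")
      .save(path)

    // With the fix, the string-typed partition value "-0" is preserved on read:
    // spark.read.parquet(path) returns the rows (0, "AA") and (1, "-0").
    ```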
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, this PR can change query results.
    
    ### How was this patch tested?
    By running the new test and existing test suites:
    ```
    $ build/sbt "test:testOnly *FileIndexSuite"
    $ build/sbt "test:testOnly *ParquetV1PartitionDiscoverySuite"
    $ build/sbt "test:testOnly *ParquetV2PartitionDiscoverySuite"
    ```
    
    Closes #31549 from MaxGekk/fix-partition-file-index-2.
    
    Authored-by: Max Gekk <ma...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../execution/datasources/PartitioningUtils.scala  | 118 ++++++++++++---------
 .../parquet/ParquetPartitionDiscoverySuite.scala   |  73 ++++++++-----
 2 files changed, 109 insertions(+), 82 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index 69123ee..0a5ddae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -25,6 +25,7 @@ import java.util.Locale
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.util.Try
+import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.Path
 
@@ -36,6 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal}
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
+import org.apache.spark.unsafe.types.UTF8String
 
 // TODO: We should tighten up visibility of the classes here once we clean up Hive coupling.
 
@@ -62,9 +64,11 @@ object PartitioningUtils {
 
   val timestampPartitionPattern = "yyyy-MM-dd HH:mm:ss[.S]"
 
-  private[datasources] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal])
+  case class TypedPartValue(value: String, dataType: DataType)
+
+  case class PartitionValues(columnNames: Seq[String], typedValues: Seq[TypedPartValue])
   {
-    require(columnNames.size == literals.size)
+    require(columnNames.size == typedValues.size)
   }
 
   import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.{escapePathName, unescapePathName, DEFAULT_PARTITION_NAME}
@@ -172,13 +176,12 @@ object PartitioningUtils {
           "root directory of the table. If there are multiple root directories, " +
           "please load them separately and then union them.")
 
-      val resolvedPartitionValues =
-        resolvePartitions(pathsWithPartitionValues, caseSensitive, zoneId)
+      val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues, caseSensitive)
 
       // Creates the StructType which represents the partition columns.
       val fields = {
-        val PartitionValues(columnNames, literals) = resolvedPartitionValues.head
-        columnNames.zip(literals).map { case (name, Literal(_, dataType)) =>
+        val PartitionValues(columnNames, typedValues) = resolvedPartitionValues.head
+        columnNames.zip(typedValues).map { case (name, TypedPartValue(_, dataType)) =>
           // We always assume partition columns are nullable since we've no idea whether null values
           // will be appended in the future.
           val resultName = userSpecifiedNames.getOrElse(name, name)
@@ -189,8 +192,19 @@ object PartitioningUtils {
 
       // Finally, we create `Partition`s based on paths and resolved partition values.
       val partitions = resolvedPartitionValues.zip(pathsWithPartitionValues).map {
-        case (PartitionValues(_, literals), (path, _)) =>
-          PartitionPath(InternalRow.fromSeq(literals.map(_.value)), path)
+        case (PartitionValues(columnNames, typedValues), (path, _)) =>
+          val rowValues = columnNames.zip(typedValues).map { case (columnName, typedValue) =>
+            try {
+              castPartValueToDesiredType(typedValue.dataType, typedValue.value, zoneId)
+            } catch {
+              case NonFatal(_) =>
+                if (validatePartitionColumns) {
+                  throw new RuntimeException(s"Failed to cast value `${typedValue.value}` to " +
+                    s"`${typedValue.dataType}` for partition column `$columnName`")
+                } else null
+            }
+          }
+          PartitionPath(InternalRow.fromSeq(rowValues), path)
       }
 
       PartitionSpec(StructType(fields), partitions)
@@ -226,7 +240,7 @@ object PartitioningUtils {
       zoneId: ZoneId,
       dateFormatter: DateFormatter,
       timestampFormatter: TimestampFormatter): (Option[PartitionValues], Option[Path]) = {
-    val columns = ArrayBuffer.empty[(String, Literal)]
+    val columns = ArrayBuffer.empty[(String, TypedPartValue)]
     // Old Hadoop versions don't have `Path.isRoot`
     var finished = path.getParent == null
     // currentPath is the current path that we will use to parse partition column value.
@@ -284,7 +298,7 @@ object PartitioningUtils {
       validatePartitionColumns: Boolean,
       zoneId: ZoneId,
       dateFormatter: DateFormatter,
-      timestampFormatter: TimestampFormatter): Option[(String, Literal)] = {
+      timestampFormatter: TimestampFormatter): Option[(String, TypedPartValue)] = {
     val equalSignIndex = columnSpec.indexOf('=')
     if (equalSignIndex == -1) {
       None
@@ -295,23 +309,10 @@ object PartitioningUtils {
       val rawColumnValue = columnSpec.drop(equalSignIndex + 1)
       assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'")
 
-      val literal = if (userSpecifiedDataTypes.contains(columnName)) {
+      val dataType = if (userSpecifiedDataTypes.contains(columnName)) {
         // SPARK-26188: if user provides corresponding column schema, get the column value without
         //              inference, and then cast it as user specified data type.
-        val dataType = userSpecifiedDataTypes(columnName)
-        val columnValueLiteral = inferPartitionColumnValue(
-          rawColumnValue,
-          false,
-          zoneId,
-          dateFormatter,
-          timestampFormatter)
-        val columnValue = columnValueLiteral.eval()
-        val castedValue = Cast(columnValueLiteral, dataType, Option(zoneId.getId)).eval()
-        if (validatePartitionColumns && columnValue != null && castedValue == null) {
-          throw new RuntimeException(s"Failed to cast value `$columnValue` to `$dataType` " +
-            s"for partition column `$columnName`")
-        }
-        Literal.create(castedValue, dataType)
+        userSpecifiedDataTypes(columnName)
       } else {
         inferPartitionColumnValue(
           rawColumnValue,
@@ -320,7 +321,7 @@ object PartitioningUtils {
           dateFormatter,
           timestampFormatter)
       }
-      Some(columnName -> literal)
+      Some(columnName -> TypedPartValue(rawColumnValue, dataType))
     }
   }
 
@@ -363,8 +364,7 @@ object PartitioningUtils {
    */
   def resolvePartitions(
       pathsWithPartitionValues: Seq[(Path, PartitionValues)],
-      caseSensitive: Boolean,
-      zoneId: ZoneId): Seq[PartitionValues] = {
+      caseSensitive: Boolean): Seq[PartitionValues] = {
     if (pathsWithPartitionValues.isEmpty) {
       Seq.empty
     } else {
@@ -381,12 +381,12 @@ object PartitioningUtils {
       val values = pathsWithPartitionValues.map(_._2)
       val columnCount = values.head.columnNames.size
       val resolvedValues = (0 until columnCount).map { i =>
-        resolveTypeConflicts(values.map(_.literals(i)), zoneId)
+        resolveTypeConflicts(values.map(_.typedValues(i)))
       }
 
       // Fills resolved literals back to each partition
       values.zipWithIndex.map { case (d, index) =>
-        d.copy(literals = resolvedValues.map(_(index)))
+        d.copy(typedValues = resolvedValues.map(_(index)))
       }
     }
   }
@@ -449,7 +449,7 @@ object PartitioningUtils {
       typeInference: Boolean,
       zoneId: ZoneId,
       dateFormatter: DateFormatter,
-      timestampFormatter: TimestampFormatter): Literal = {
+      timestampFormatter: TimestampFormatter): DataType = {
     val decimalTry = Try {
       // `BigDecimal` conversion can fail when the `field` is not a form of number.
       val bigDecimal = new JBigDecimal(raw)
@@ -458,7 +458,7 @@ object PartitioningUtils {
       // `DecimalType` conversion can fail when
       //   1. The precision is bigger than 38.
       //   2. scale is bigger than precision.
-      Literal(bigDecimal)
+      DecimalType.fromDecimal(Decimal(bigDecimal))
     }
 
     val dateTry = Try {
@@ -474,7 +474,7 @@ object PartitioningUtils {
       val dateValue = Cast(Literal(raw), DateType, Some(zoneId.getId)).eval()
       // Disallow DateType if the cast returned null
       require(dateValue != null)
-      Literal.create(dateValue, DateType)
+      DateType
     }
 
     val timestampTry = Try {
@@ -486,36 +486,50 @@ object PartitioningUtils {
       val timestampValue = Cast(Literal(unescapedRaw), TimestampType, Some(zoneId.getId)).eval()
       // Disallow TimestampType if the cast returned null
       require(timestampValue != null)
-      Literal.create(timestampValue, TimestampType)
+      TimestampType
     }
 
     if (typeInference) {
       // First tries integral types
-      Try(Literal.create(Integer.parseInt(raw), IntegerType))
-        .orElse(Try(Literal.create(JLong.parseLong(raw), LongType)))
+      Try({ Integer.parseInt(raw); IntegerType })
+        .orElse(Try { JLong.parseLong(raw); LongType })
         .orElse(decimalTry)
         // Then falls back to fractional types
-        .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType)))
+        .orElse(Try { JDouble.parseDouble(raw); DoubleType })
         // Then falls back to date/timestamp types
         .orElse(timestampTry)
         .orElse(dateTry)
         // Then falls back to string
         .getOrElse {
-          if (raw == DEFAULT_PARTITION_NAME) {
-            Literal.create(null, NullType)
-          } else {
-            Literal.create(unescapePathName(raw), StringType)
-          }
+          if (raw == DEFAULT_PARTITION_NAME) NullType else StringType
         }
     } else {
-      if (raw == DEFAULT_PARTITION_NAME) {
-        Literal.create(null, NullType)
-      } else {
-        Literal.create(unescapePathName(raw), StringType)
-      }
+      if (raw == DEFAULT_PARTITION_NAME) NullType else StringType
     }
   }
 
+  def castPartValueToDesiredType(
+      desiredType: DataType,
+      value: String,
+      zoneId: ZoneId): Any = desiredType match {
+    case _ if value == DEFAULT_PARTITION_NAME => null
+    case NullType => null
+    case StringType => UTF8String.fromString(unescapePathName(value))
+    case IntegerType => Integer.parseInt(value)
+    case LongType => JLong.parseLong(value)
+    case DoubleType => JDouble.parseDouble(value)
+    case _: DecimalType => Literal(new JBigDecimal(value)).value
+    case DateType =>
+      Cast(Literal(value), DateType, Some(zoneId.getId)).eval()
+    case TimestampType =>
+      Try {
+        Cast(Literal(unescapePathName(value)), TimestampType, Some(zoneId.getId)).eval()
+      }.getOrElse {
+        Cast(Cast(Literal(value), DateType, Some(zoneId.getId)), TimestampType).eval()
+      }
+    case dt => throw new IllegalArgumentException(s"Unexpected type $dt")
+  }
+
   def validatePartitionColumn(
       schema: StructType,
       partitionColumns: Seq[String],
@@ -590,13 +604,11 @@ object PartitioningUtils {
    * Given a collection of [[Literal]]s, resolves possible type conflicts by
    * [[findWiderTypeForPartitionColumn]].
    */
-  private def resolveTypeConflicts(literals: Seq[Literal], zoneId: ZoneId): Seq[Literal] = {
-    val litTypes = literals.map(_.dataType)
-    val desiredType = litTypes.reduce(findWiderTypeForPartitionColumn)
+  private def resolveTypeConflicts(typedValues: Seq[TypedPartValue]): Seq[TypedPartValue] = {
+    val dataTypes = typedValues.map(_.dataType)
+    val desiredType = dataTypes.reduce(findWiderTypeForPartitionColumn)
 
-    literals.map { case l @ Literal(_, dataType) =>
-      Literal.create(Cast(l, desiredType, Some(zoneId.getId)).eval(), desiredType)
-    }
+    typedValues.map(tv => tv.copy(dataType = desiredType))
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index 400f4d8..5ea8c61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet
 
 import java.io.File
 import java.math.BigInteger
-import java.sql.{Date, Timestamp}
+import java.sql.Timestamp
 import java.time.{ZoneId, ZoneOffset}
 import java.util.{Calendar, Locale}
 
@@ -31,7 +31,6 @@ import org.apache.spark.SparkConf
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
-import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneUTC
 import org.apache.spark.sql.execution.datasources._
@@ -72,29 +71,25 @@ abstract class ParquetPartitionDiscoverySuite
   }
 
   test("column type inference") {
-    def check(raw: String, literal: Literal, zoneId: ZoneId = timeZoneId): Unit = {
-      assert(inferPartitionColumnValue(raw, true, zoneId, df, tf) === literal)
+    def check(raw: String, dataType: DataType, zoneId: ZoneId = timeZoneId): Unit = {
+      assert(inferPartitionColumnValue(raw, true, zoneId, df, tf) === dataType)
     }
 
-    check("10", Literal.create(10, IntegerType))
-    check("1000000000000000", Literal.create(1000000000000000L, LongType))
+    check("10", IntegerType)
+    check("1000000000000000", LongType)
     val decimal = Decimal("1" * 20)
-    check("1" * 20,
-      Literal.create(decimal, DecimalType(decimal.precision, decimal.scale)))
-    check("1.5", Literal.create(1.5, DoubleType))
-    check("hello", Literal.create("hello", StringType))
-    check("1990-02-24", Literal.create(Date.valueOf("1990-02-24"), DateType))
-    check("1990-02-24 12:00:30",
-      Literal.create(Timestamp.valueOf("1990-02-24 12:00:30"), TimestampType))
+    check("1" * 20, DecimalType(decimal.precision, decimal.scale))
+    check("1.5", DoubleType)
+    check("hello", StringType)
+    check("1990-02-24", DateType)
+    check("1990-02-24 12:00:30", TimestampType)
 
     val c = Calendar.getInstance(TimeZoneUTC)
     c.set(1990, 1, 24, 12, 0, 30)
     c.set(Calendar.MILLISECOND, 0)
-    check("1990-02-24 12:00:30",
-      Literal.create(new Timestamp(c.getTimeInMillis), TimestampType),
-      ZoneOffset.UTC)
+    check("1990-02-24 12:00:30", TimestampType, ZoneOffset.UTC)
 
-    check(defaultPartitionName, Literal.create(null, NullType))
+    check(defaultPartitionName, NullType)
   }
 
   test("parse invalid partitioned directories") {
@@ -216,22 +211,22 @@ abstract class ParquetPartitionDiscoverySuite
     check("file://path/a=10", Some {
       PartitionValues(
         Seq("a"),
-        Seq(Literal.create(10, IntegerType)))
+        Seq(TypedPartValue("10", IntegerType)))
     })
 
     check("file://path/a=10/b=hello/c=1.5", Some {
       PartitionValues(
         Seq("a", "b", "c"),
         Seq(
-          Literal.create(10, IntegerType),
-          Literal.create("hello", StringType),
-          Literal.create(1.5, DoubleType)))
+          TypedPartValue("10", IntegerType),
+          TypedPartValue("hello", StringType),
+          TypedPartValue("1.5", DoubleType)))
     })
 
     check("file://path/a=10/b_hello/c=1.5", Some {
       PartitionValues(
         Seq("c"),
-        Seq(Literal.create(1.5, DoubleType)))
+        Seq(TypedPartValue("1.5", DoubleType)))
     })
 
     check("file:///", None)
@@ -273,7 +268,7 @@ abstract class ParquetPartitionDiscoverySuite
     assert(partitionSpec2 ==
       Option(PartitionValues(
         Seq("a"),
-        Seq(Literal.create(10, IntegerType)))))
+        Seq(TypedPartValue("10", IntegerType)))))
   }
 
   test("parse partitions") {
@@ -911,15 +906,19 @@ abstract class ParquetPartitionDiscoverySuite
     assert(
       listConflictingPartitionColumns(
         Seq(
-          (new Path("file:/tmp/foo/a=1"), PartitionValues(Seq("a"), Seq(Literal(1)))),
-          (new Path("file:/tmp/foo/b=1"), PartitionValues(Seq("b"), Seq(Literal(1)))))).trim ===
+          (new Path("file:/tmp/foo/a=1"),
+            PartitionValues(Seq("a"), Seq(TypedPartValue("1", IntegerType)))),
+          (new Path("file:/tmp/foo/b=1"),
+            PartitionValues(Seq("b"), Seq(TypedPartValue("1", IntegerType)))))).trim ===
         makeExpectedMessage(Seq("a", "b"), Seq("file:/tmp/foo/a=1", "file:/tmp/foo/b=1")))
 
     assert(
       listConflictingPartitionColumns(
         Seq(
-          (new Path("file:/tmp/foo/a=1/_temporary"), PartitionValues(Seq("a"), Seq(Literal(1)))),
-          (new Path("file:/tmp/foo/a=1"), PartitionValues(Seq("a"), Seq(Literal(1)))))).trim ===
+          (new Path("file:/tmp/foo/a=1/_temporary"),
+            PartitionValues(Seq("a"), Seq(TypedPartValue("1", IntegerType)))),
+          (new Path("file:/tmp/foo/a=1"),
+            PartitionValues(Seq("a"), Seq(TypedPartValue("1", IntegerType)))))).trim ===
         makeExpectedMessage(
           Seq("a"),
           Seq("file:/tmp/foo/a=1/_temporary", "file:/tmp/foo/a=1")))
@@ -928,9 +927,10 @@ abstract class ParquetPartitionDiscoverySuite
       listConflictingPartitionColumns(
         Seq(
           (new Path("file:/tmp/foo/a=1"),
-            PartitionValues(Seq("a"), Seq(Literal(1)))),
+            PartitionValues(Seq("a"), Seq(TypedPartValue("1", IntegerType)))),
           (new Path("file:/tmp/foo/a=1/b=foo"),
-            PartitionValues(Seq("a", "b"), Seq(Literal(1), Literal("foo")))))).trim ===
+            PartitionValues(Seq("a", "b"),
+              Seq(TypedPartValue("1", IntegerType), TypedPartValue("foo", StringType)))))).trim ===
         makeExpectedMessage(
           Seq("a", "a, b"),
           Seq("file:/tmp/foo/a=1", "file:/tmp/foo/a=1/b=foo")))
@@ -1039,6 +1039,21 @@ abstract class ParquetPartitionDiscoverySuite
       checkAnswer(input, data)
     }
   }
+
+  test("SPARK-34314: preserve partition values of the string type") {
+    import testImplicits._
+    withTempPath { file =>
+      val path = file.getCanonicalPath
+      val df = Seq((0, "AA"), (1, "-0")).toDF("id", "part")
+      df.write
+        .partitionBy("part")
+        .format("parquet")
+        .save(path)
+      val readback = spark.read.parquet(path)
+      assert(readback.schema("part").dataType === StringType)
+      checkAnswer(readback, Row(0, "AA") :: Row(1, "-0") :: Nil)
+    }
+  }
 }
 
 class ParquetV1PartitionDiscoverySuite extends ParquetPartitionDiscoverySuite {

