You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2020/03/02 14:52:26 UTC

[spark] branch branch-3.0 updated: [SPARK-30993][SQL] Use its sql type for UDT when checking the type of length (fixed/var) or mutable

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 6c4977d  [SPARK-30993][SQL] Use its sql type for UDT when checking the type of length (fixed/var) or mutable
6c4977d is described below

commit 6c4977d38f13628abfa24129ae6844146672d96d
Author: Jungtaek Lim (HeartSaVioR) <ka...@gmail.com>
AuthorDate: Mon Mar 2 22:33:11 2020 +0800

    [SPARK-30993][SQL] Use its sql type for UDT when checking the type of length (fixed/var) or mutable
    
    ### What changes were proposed in this pull request?
    
    This patch fixes a bug in UnsafeRow: `isFixedLength` and `isMutable` do not handle UDTs specially. These methods don't check the underlying SQL type of a UDT, so a UDT is always treated as variable-length and non-mutable.
    
    This causes no issue when a UDT represents a complex type. However, when a UDT's underlying SQL type is fixed-length, it opens the door to correctness issues, because this classification sometimes determines how the value is handled.
    
    We received a report on the user mailing list suspecting that mapGroupsWithState handles UDTs incorrectly. After some investigation, the problem turned out to be in GenerateUnsafeRowJoiner during the shuffle phase.
    
    https://github.com/apache/spark/blob/0e2ca11d80c3921387d7b077cb64c3a0c06b08d7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala#L32-L43
    
    Here, the position (offset) update should only be applied to variable-length columns. Due to this bug, a UDT whose SQL type is fixed-length was treated as variable-length, so its value was modified by this update, which corrupts the value.
    
    ### Why are the changes needed?
    
    Misclassifying a UDT as variable-length when its SQL type is fixed-length can corrupt its value when the row is passed as input to GenerateUnsafeRowJoiner, which causes a correctness issue.
    
    ### Does this PR introduce any user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New unit tests added in GenerateUnsafeRowJoinerSuite and UserDefinedTypeSuite.
    
    Closes #27747 from HeartSaVioR/SPARK-30993.
    
    Authored-by: Jungtaek Lim (HeartSaVioR) <ka...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit f24a46011c8cba086193f697d653b6eccd029e8f)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/expressions/UnsafeRow.java  |  8 +++++
 .../codegen/GenerateUnsafeRowJoinerSuite.scala     | 41 +++++++++++++++++++++-
 .../apache/spark/sql/UserDefinedTypeSuite.scala    | 37 +++++++++++++++++++
 3 files changed, 85 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 23e7d1f..034894b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -95,6 +95,10 @@ public final class UnsafeRow extends InternalRow implements Externalizable, Kryo
   }
 
   public static boolean isFixedLength(DataType dt) {
+    if (dt instanceof UserDefinedType) {
+      return isFixedLength(((UserDefinedType) dt).sqlType());
+    }
+
     if (dt instanceof DecimalType) {
       return ((DecimalType) dt).precision() <= Decimal.MAX_LONG_DIGITS();
     } else {
@@ -103,6 +107,10 @@ public final class UnsafeRow extends InternalRow implements Externalizable, Kryo
   }
 
   public static boolean isMutable(DataType dt) {
+    if (dt instanceof UserDefinedType) {
+      return isMutable(((UserDefinedType) dt).sqlType());
+    }
+
     return mutableFieldTypes.contains(dt) || dt instanceof DecimalType ||
       dt instanceof CalendarIntervalType;
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
index 81e2993..fb1ea7b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions.codegen
 
+import java.time.{LocalDateTime, ZoneOffset}
+
 import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.RandomDataGenerator
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -99,6 +101,23 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
     testConcatOnce(N, N, variable)
   }
 
+  test("SPARK-30993: UserDefinedType matched to fixed length SQL type shouldn't be corrupted") {
+    val schema1 = new StructType(Array(
+      StructField("date", new WrappedDateTimeUDT),
+      StructField("s", StringType),
+      StructField("i", IntegerType)))
+    val proj1 = UnsafeProjection.create(schema1.fields.map(_.dataType))
+    val intRow1 = new GenericInternalRow(Array[Any](
+      LocalDateTime.now().toEpochSecond(ZoneOffset.UTC),
+      UTF8String.fromString("hello"), 1))
+
+    val schema2 = new StructType(Array(StructField("i", IntegerType)))
+    val proj2 = UnsafeProjection.create(schema2.fields.map(_.dataType))
+    val intRow2 = new GenericInternalRow(Array[Any](2))
+
+    testConcat(schema1, proj1.apply(intRow1), schema2, proj2.apply(intRow2))
+  }
+
   private def testConcat(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]): Unit = {
     for (i <- 0 until 10) {
       testConcatOnce(numFields1, numFields2, candidateTypes)
@@ -204,3 +223,23 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
   }
 
 }
+
+private[sql] case class WrappedDateTime(dt: LocalDateTime)
+
+private[sql] class WrappedDateTimeUDT extends UserDefinedType[WrappedDateTime] {
+  override def sqlType: DataType = LongType
+
+  override def serialize(obj: WrappedDateTime): Long = {
+    obj.dt.toEpochSecond(ZoneOffset.UTC)
+  }
+
+  def deserialize(datum: Any): WrappedDateTime = datum match {
+    case value: Long =>
+      val v = LocalDateTime.ofEpochSecond(value, 0, ZoneOffset.UTC)
+      WrappedDateTime(v)
+  }
+
+  override def userClass: Class[WrappedDateTime] = classOf[WrappedDateTime]
+
+  private[spark] override def asNullable: WrappedDateTimeUDT = this
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index ffc2018d..157610f5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import java.time.{LocalDateTime, ZoneOffset}
 import java.util.Arrays
 
 import org.apache.spark.rdd.RDD
@@ -103,6 +104,24 @@ private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType]
   override def userClass: Class[IExampleSubType] = classOf[IExampleSubType]
 }
 
+private[sql] case class FooWithDate(date: LocalDateTime, s: String, i: Int)
+
+private[sql] class LocalDateTimeUDT extends UserDefinedType[LocalDateTime] {
+  override def sqlType: DataType = LongType
+
+  override def serialize(obj: LocalDateTime): Long = {
+    obj.toEpochSecond(ZoneOffset.UTC)
+  }
+
+  def deserialize(datum: Any): LocalDateTime = datum match {
+    case value: Long => LocalDateTime.ofEpochSecond(value, 0, ZoneOffset.UTC)
+  }
+
+  override def userClass: Class[LocalDateTime] = classOf[LocalDateTime]
+
+  private[spark] override def asNullable: LocalDateTimeUDT = this
+}
+
 class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with ParquetTest
     with ExpressionEvalHelper {
   import testImplicits._
@@ -287,4 +306,22 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque
     checkAnswer(spark.createDataFrame(data, schema).selectExpr("typeof(a)"),
       Seq(Row("array<double>")))
   }
+
+  test("SPARK-30993: UserDefinedType matched to fixed length SQL type shouldn't be corrupted") {
+    def concatFoo(a: FooWithDate, b: FooWithDate): FooWithDate = {
+      FooWithDate(b.date, a.s + b.s, a.i)
+    }
+
+    UDTRegistration.register(classOf[LocalDateTime].getName, classOf[LocalDateTimeUDT].getName)
+
+    // remove sub-millisecond part as we only use millis based timestamp while serde
+    val date = LocalDateTime.ofEpochSecond(LocalDateTime.now().toEpochSecond(ZoneOffset.UTC),
+      0, ZoneOffset.UTC)
+    val inputDS = List(FooWithDate(date, "Foo", 1), FooWithDate(date, "Foo", 3),
+      FooWithDate(date, "Foo", 3)).toDS()
+    val agg = inputDS.groupByKey(x => x.i).mapGroups((_, iter) => iter.reduce(concatFoo))
+    val result = agg.collect()
+
+    assert(result.toSet === Set(FooWithDate(date, "FooFoo", 3), FooWithDate(date, "Foo", 1)))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org