You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2015/12/23 03:21:10 UTC

spark git commit: [SPARK-12478][SQL] Bugfix: Dataset fields of product types can't be null

Repository: spark
Updated Branches:
  refs/heads/master 20591afd7 -> 86761e10e


[SPARK-12478][SQL] Bugfix: Dataset fields of product types can't be null

When creating extractors for product types (i.e. case classes and tuples), a null check is missing, so input product values are always assumed to be non-null.

This PR adds a null check in the extractor expression for product types. The null check is stripped off for top level product fields, which are mapped to the outermost `Row`s, since they can't be null.

Thanks cloud-fan for helping investigate this issue!

Author: Cheng Lian <li...@databricks.com>

Closes #10431 from liancheng/spark-12478.top-level-null-field.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/86761e10
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/86761e10
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/86761e10

Branch: refs/heads/master
Commit: 86761e10e145b6867cbe86b1e924ec237ba408af
Parents: 20591af
Author: Cheng Lian <li...@databricks.com>
Authored: Wed Dec 23 10:21:00 2015 +0800
Committer: Cheng Lian <li...@databricks.com>
Committed: Wed Dec 23 10:21:00 2015 +0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/catalyst/ScalaReflection.scala  |  8 ++++----
 .../test/scala/org/apache/spark/sql/DatasetSuite.scala   | 11 +++++++++++
 2 files changed, 15 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/86761e10/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index becd019..8a22b37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -380,7 +380,7 @@ object ScalaReflection extends ScalaReflection {
     val clsName = getClassNameFromType(tpe)
     val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
     extractorFor(inputObject, tpe, walkedTypePath) match {
-      case s: CreateNamedStruct => s
+      case expressions.If(_, _, s: CreateNamedStruct) if tpe <:< localTypeOf[Product] => s
       case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
     }
   }
@@ -466,14 +466,14 @@ object ScalaReflection extends ScalaReflection {
 
         case t if t <:< localTypeOf[Product] =>
           val params = getConstructorParameters(t)
-
-          CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
+          val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
             val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
             val clsName = getClassNameFromType(fieldType)
             val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
-
             expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil
           })
+          val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+          expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
 
         case t if t <:< localTypeOf[Array[_]] =>
           val TypeRef(_, _, Seq(elementType)) = t

http://git-wip-us.apache.org/repos/asf/spark/blob/86761e10/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3337996..7fe66e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -546,6 +546,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
     ))
   }
+
+  test("SPARK-12478: top level null field") {
+    val ds0 = Seq(NestedStruct(null)).toDS()
+    checkAnswer(ds0, NestedStruct(null))
+    checkAnswer(ds0.toDF(), Row(null))
+
+    val ds1 = Seq(DeepNestedStruct(NestedStruct(null))).toDS()
+    checkAnswer(ds1, DeepNestedStruct(NestedStruct(null)))
+    checkAnswer(ds1.toDF(), Row(Row(null)))
+  }
 }
 
 case class ClassData(a: String, b: Int)
@@ -553,6 +563,7 @@ case class ClassData2(c: String, d: Int)
 case class ClassNullableData(a: String, b: Integer)
 
 case class NestedStruct(f: ClassData)
+case class DeepNestedStruct(f: NestedStruct)
 
 /**
  * A class used to test serialization using encoders. This class throws exceptions when using


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org