You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2015/12/23 03:21:10 UTC
spark git commit: [SPARK-12478][SQL] Bugfix: Dataset fields of
product types can't be null
Repository: spark
Updated Branches:
refs/heads/master 20591afd7 -> 86761e10e
[SPARK-12478][SQL] Bugfix: Dataset fields of product types can't be null
When creating extractors for product types (i.e. case classes and tuples), a null check is missing, thus we always assume input product values are non-null.
This PR adds a null check in the extractor expression for product types. The null check is stripped off for top level product fields, which are mapped to the outermost `Row`s, since they can't be null.
Thanks cloud-fan for helping investigate this issue!
Author: Cheng Lian <li...@databricks.com>
Closes #10431 from liancheng/spark-12478.top-level-null-field.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/86761e10
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/86761e10
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/86761e10
Branch: refs/heads/master
Commit: 86761e10e145b6867cbe86b1e924ec237ba408af
Parents: 20591af
Author: Cheng Lian <li...@databricks.com>
Authored: Wed Dec 23 10:21:00 2015 +0800
Committer: Cheng Lian <li...@databricks.com>
Committed: Wed Dec 23 10:21:00 2015 +0800
----------------------------------------------------------------------
.../org/apache/spark/sql/catalyst/ScalaReflection.scala | 8 ++++----
.../test/scala/org/apache/spark/sql/DatasetSuite.scala | 11 +++++++++++
2 files changed, 15 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/86761e10/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index becd019..8a22b37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -380,7 +380,7 @@ object ScalaReflection extends ScalaReflection {
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
extractorFor(inputObject, tpe, walkedTypePath) match {
- case s: CreateNamedStruct => s
+ case expressions.If(_, _, s: CreateNamedStruct) if tpe <:< localTypeOf[Product] => s
case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
}
}
@@ -466,14 +466,14 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[Product] =>
val params = getConstructorParameters(t)
-
- CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
+ val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
val clsName = getClassNameFromType(fieldType)
val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
-
expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil
})
+ val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+ expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
http://git-wip-us.apache.org/repos/asf/spark/blob/86761e10/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3337996..7fe66e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -546,6 +546,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
"Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
))
}
+
+ test("SPARK-12478: top level null field") {
+ val ds0 = Seq(NestedStruct(null)).toDS()
+ checkAnswer(ds0, NestedStruct(null))
+ checkAnswer(ds0.toDF(), Row(null))
+
+ val ds1 = Seq(DeepNestedStruct(NestedStruct(null))).toDS()
+ checkAnswer(ds1, DeepNestedStruct(NestedStruct(null)))
+ checkAnswer(ds1.toDF(), Row(Row(null)))
+ }
}
case class ClassData(a: String, b: Int)
@@ -553,6 +563,7 @@ case class ClassData2(c: String, d: Int)
case class ClassNullableData(a: String, b: Integer)
case class NestedStruct(f: ClassData)
+case class DeepNestedStruct(f: NestedStruct)
/**
* A class used to test serialization using encoders. This class throws exceptions when using
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org