You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/11/10 05:56:23 UTC

spark git commit: [SPARK-22472][SQL] add null check for top-level primitive values

Repository: spark
Updated Branches:
  refs/heads/master b57ed2245 -> 0025ddeb1


[SPARK-22472][SQL] add null check for top-level primitive values

## What changes were proposed in this pull request?

One powerful feature of `Dataset` is that we can easily map SQL rows to Scala/Java objects and have null checks performed automatically at runtime.

For example, let's say we have a parquet file with schema `<a: int, b: string>`, and we have a `case class Data(a: Int, b: String)`. Users can easily read this parquet file into `Data` objects, and Spark will throw an NPE if column `a` has null values.

However, this null check is missing for top-level primitive values. For example, let's say we have a parquet file with schema `<v: int>`, and we read it into Scala `Int`. If column `v` has null values, we get some weird results:
```
scala> val ds = spark.read.parquet(...).as[Int]

scala> ds.show()
+----+
|v   |
+----+
|null|
|1   |
+----+

scala> ds.collect
res0: Array[Int] = Array(0, 1)

scala> ds.map(_ * 2).show
+-----+
|value|
+-----+
|-2   |
|2    |
+-----+
```

This is because Spark internally uses special default values for primitive types, but never expects users to see or operate on these default values directly.

This PR adds a null check for top-level primitive values.

## How was this patch tested?

new test

Author: Wenchen Fan <we...@databricks.com>

Closes #19707 from cloud-fan/bug.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0025ddeb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0025ddeb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0025ddeb

Branch: refs/heads/master
Commit: 0025ddeb1dd4fd6951ecd8456457f6b94124f84e
Parents: b57ed22
Author: Wenchen Fan <we...@databricks.com>
Authored: Thu Nov 9 21:56:20 2017 -0800
Committer: gatorsmile <ga...@gmail.com>
Committed: Thu Nov 9 21:56:20 2017 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala      |  8 +++++++-
 .../spark/sql/catalyst/ScalaReflectionSuite.scala |  7 ++++++-
 .../scala/org/apache/spark/sql/DatasetSuite.scala | 18 ++++++++++++++++++
 3 files changed, 31 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0025ddeb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index f62553d..4e47a58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -134,7 +134,13 @@ object ScalaReflection extends ScalaReflection {
     val tpe = localTypeOf[T]
     val clsName = getClassNameFromType(tpe)
     val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
-    deserializerFor(tpe, None, walkedTypePath)
+    val expr = deserializerFor(tpe, None, walkedTypePath)
+    val Schema(_, nullable) = schemaFor(tpe)
+    if (nullable) {
+      expr
+    } else {
+      AssertNotNull(expr, walkedTypePath)
+    }
   }
 
   private def deserializerFor(

http://git-wip-us.apache.org/repos/asf/spark/blob/0025ddeb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index f77af5d..23e866c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast}
-import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -351,4 +351,9 @@ class ScalaReflectionSuite extends SparkFunSuite {
     assert(argumentsFields(0) == Seq("field.1"))
     assert(argumentsFields(1) == Seq("field 2"))
   }
+
+  test("SPARK-22472: add null check for top-level primitive values") {
+    assert(deserializerFor[Int].isInstanceOf[AssertNotNull])
+    assert(!deserializerFor[String].isInstanceOf[AssertNotNull])
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0025ddeb/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index c67165c..6e13a5d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.sql.{Date, Timestamp}
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.util.sideBySide
@@ -1408,6 +1409,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       checkDataset(ds, SpecialCharClass("1", "2"))
     }
   }
+
+  test("SPARK-22472: add null check for top-level primitive values") {
+    // If the primitive values are from Option, we need to do runtime null check.
+    val ds = Seq(Some(1), None).toDS().as[Int]
+    intercept[NullPointerException](ds.collect())
+    val e = intercept[SparkException](ds.map(_ * 2).collect())
+    assert(e.getCause.isInstanceOf[NullPointerException])
+
+    withTempPath { path =>
+      Seq(new Integer(1), null).toDF("i").write.parquet(path.getCanonicalPath)
+      // If the primitive values are from files, we need to do runtime null check.
+      val ds = spark.read.parquet(path.getCanonicalPath).as[Int]
+      intercept[NullPointerException](ds.collect())
+      val e = intercept[SparkException](ds.map(_ * 2).collect())
+      assert(e.getCause.isInstanceOf[NullPointerException])
+    }
+  }
 }
 
 case class SingleData(id: Int)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org