Posted to commits@spark.apache.org by li...@apache.org on 2017/11/10 05:56:23 UTC
spark git commit: [SPARK-22472][SQL] add null check for top-level primitive values
Repository: spark
Updated Branches:
refs/heads/master b57ed2245 -> 0025ddeb1
[SPARK-22472][SQL] add null check for top-level primitive values
## What changes were proposed in this pull request?
One powerful feature of `Dataset` is that we can easily map SQL rows to Scala/Java objects, with an automatic runtime null check.
For example, let's say we have a parquet file with schema `<a: int, b: string>` and a `case class Data(a: Int, b: String)`. Users can easily read this parquet file into `Data` objects, and Spark will throw an NPE if column `a` contains null values.
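A minimal sketch of this behavior (the parquet path and the `SparkSession` setup are illustrative assumptions, not part of this patch):
```scala
import org.apache.spark.sql.SparkSession

case class Data(a: Int, b: String)

// Hypothetical setup: a local SparkSession and a parquet file at
// /tmp/data.parquet with schema <a: int, b: string>.
val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// If column `a` contains a null, deserialization throws an NPE instead of
// silently writing a default value into the non-nullable field `a`.
val ds = spark.read.parquet("/tmp/data.parquet").as[Data]
ds.collect()
```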
However, this null check is missing for top-level primitive values. For example, let's say we have a parquet file with schema `<a: Int>` and we read it into a Scala `Int`. If column `a` contains null values, we get some weird results:
```
scala> val ds = spark.read.parquet(...).as[Int]

scala> ds.show()
+----+
|   v|
+----+
|null|
|   1|
+----+

scala> ds.collect
res0: Array[Int] = Array(0, 1)

scala> ds.map(_ * 2).show
+-----+
|value|
+-----+
|   -2|
|    2|
+-----+
```
This is because Spark internally uses special default values for primitive types, and never expects users to see or operate on these default values directly.
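As a rough illustration of where such defaults can come from (this is plain Scala/JVM unboxing semantics, not Spark's exact code path):
```scala
// Casting a null reference to a primitive type in Scala yields the type's
// default value instead of failing, so a null that reaches a primitive
// slot can silently turn into 0.
val boxed: Any = null
val unboxed: Int = boxed.asInstanceOf[Int] // 0, no exception thrown
```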
This PR adds a null check for top-level primitive values.
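With this change, the null surfaces as an error instead of a wrong value. A minimal sketch adapted from the new test in `DatasetSuite` (assumes a `SparkSession` named `spark` with `spark.implicits._` in scope):
```scala
import spark.implicits._

// Forcing an Option-typed column into a bare Int now fails fast at runtime.
val ds = Seq(Some(1), None).toDS().as[Int]
// ds.collect()             // throws NullPointerException for the None row
// ds.map(_ * 2).collect()  // throws SparkException caused by an NPE

// If nulls are expected, keep the Option type and handle it explicitly.
val safe = Seq(Some(1), None).toDS().collect() // Array(Some(1), None)
```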
## How was this patch tested?
new tests in `ScalaReflectionSuite` and `DatasetSuite`
Author: Wenchen Fan <we...@databricks.com>
Closes #19707 from cloud-fan/bug.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0025ddeb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0025ddeb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0025ddeb
Branch: refs/heads/master
Commit: 0025ddeb1dd4fd6951ecd8456457f6b94124f84e
Parents: b57ed22
Author: Wenchen Fan <we...@databricks.com>
Authored: Thu Nov 9 21:56:20 2017 -0800
Committer: gatorsmile <ga...@gmail.com>
Committed: Thu Nov 9 21:56:20 2017 -0800
----------------------------------------------------------------------
.../spark/sql/catalyst/ScalaReflection.scala | 8 +++++++-
.../spark/sql/catalyst/ScalaReflectionSuite.scala | 7 ++++++-
.../scala/org/apache/spark/sql/DatasetSuite.scala | 18 ++++++++++++++++++
3 files changed, 31 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/0025ddeb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index f62553d..4e47a58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -134,7 +134,13 @@ object ScalaReflection extends ScalaReflection {
val tpe = localTypeOf[T]
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
- deserializerFor(tpe, None, walkedTypePath)
+ val expr = deserializerFor(tpe, None, walkedTypePath)
+ val Schema(_, nullable) = schemaFor(tpe)
+ if (nullable) {
+ expr
+ } else {
+ AssertNotNull(expr, walkedTypePath)
+ }
}
private def deserializerFor(
http://git-wip-us.apache.org/repos/asf/spark/blob/0025ddeb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index f77af5d..23e866c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast}
-import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -351,4 +351,9 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(argumentsFields(0) == Seq("field.1"))
assert(argumentsFields(1) == Seq("field 2"))
}
+
+ test("SPARK-22472: add null check for top-level primitive values") {
+ assert(deserializerFor[Int].isInstanceOf[AssertNotNull])
+ assert(!deserializerFor[String].isInstanceOf[AssertNotNull])
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/0025ddeb/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index c67165c..6e13a5d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.sql.{Date, Timestamp}
+import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.util.sideBySide
@@ -1408,6 +1409,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkDataset(ds, SpecialCharClass("1", "2"))
}
}
+
+ test("SPARK-22472: add null check for top-level primitive values") {
+ // If the primitive values are from Option, we need to do runtime null check.
+ val ds = Seq(Some(1), None).toDS().as[Int]
+ intercept[NullPointerException](ds.collect())
+ val e = intercept[SparkException](ds.map(_ * 2).collect())
+ assert(e.getCause.isInstanceOf[NullPointerException])
+
+ withTempPath { path =>
+ Seq(new Integer(1), null).toDF("i").write.parquet(path.getCanonicalPath)
+ // If the primitive values are from files, we need to do runtime null check.
+ val ds = spark.read.parquet(path.getCanonicalPath).as[Int]
+ intercept[NullPointerException](ds.collect())
+ val e = intercept[SparkException](ds.map(_ * 2).collect())
+ assert(e.getCause.isInstanceOf[NullPointerException])
+ }
+ }
}
case class SingleData(id: Int)