Posted to commits@spark.apache.org by ma...@apache.org on 2016/02/08 21:06:04 UTC

spark git commit: [SPARK-13101][SQL] nullability of array type element should not fail analysis of encoder

Repository: spark
Updated Branches:
  refs/heads/master 06f0df6df -> 8e4d15f70


[SPARK-13101][SQL] nullability of array type element should not fail analysis of encoder

Nullability should only be considered an optimization rather than part of the type system, so instead of failing analysis on mismatched nullability, we should pass analysis and add a runtime null check.
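
A minimal sketch of the new behavior, condensed from the test added to EncoderResolutionSuite below (resolve/bind/fromRow are the internal encoder entry points the suite exercises):

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
  import org.apache.spark.sql.catalyst.util.GenericArrayData
  import org.apache.spark.sql.types._

  val encoder = ExpressionEncoder[Seq[Int]]
  val attrs = 'a.array(IntegerType) :: Nil  // array element is reported nullable

  // Before this patch resolve() failed analysis on the nullability mismatch;
  // now it succeeds and the null check happens at runtime instead.
  val bound = encoder.resolve(attrs, null).bind(attrs)
  bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))    // fine
  bound.fromRow(InternalRow(new GenericArrayData(Array(1, null)))) // RuntimeException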

Author: Wenchen Fan <we...@databricks.com>

Closes #11035 from cloud-fan/ignore-nullability.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8e4d15f7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8e4d15f7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8e4d15f7

Branch: refs/heads/master
Commit: 8e4d15f70713e1aaaa96dfb3ea4ccc5bb08eb2ce
Parents: 06f0df6
Author: Wenchen Fan <we...@databricks.com>
Authored: Mon Feb 8 12:06:00 2016 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Mon Feb 8 12:06:00 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/JavaTypeInference.scala  |   2 +-
 .../spark/sql/catalyst/ScalaReflection.scala    |  29 +++--
 .../spark/sql/catalyst/analysis/Analyzer.scala  |   2 +-
 .../sql/catalyst/expressions/objects.scala      |  20 ++--
 .../encoders/EncoderResolutionSuite.scala       | 107 ++++++-------------
 .../org/apache/spark/sql/JavaDatasetSuite.java  |   4 +-
 .../org/apache/spark/sql/DatasetSuite.scala     |   4 +-
 7 files changed, 64 insertions(+), 104 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8e4d15f7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 3c3717d..59ee41d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -292,7 +292,7 @@ object JavaTypeInference {
           val setter = if (nullable) {
             constructor
           } else {
-            AssertNotNull(constructor, other.getName, fieldName, fieldType.toString)
+            AssertNotNull(constructor, Seq("currently no type path record in java"))
           }
           p.getWriteMethod.getName -> setter
         }.toMap

http://git-wip-us.apache.org/repos/asf/spark/blob/8e4d15f7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index e5811ef..02cb2d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -249,6 +249,8 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
+
+        // TODO: add runtime null check for primitive array
         val primitiveMethod = elementType match {
           case t if t <:< definitions.IntTpe => Some("toIntArray")
           case t if t <:< definitions.LongTpe => Some("toLongArray")
@@ -276,22 +278,29 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[Seq[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
+        val Schema(dataType, nullable) = schemaFor(elementType)
         val className = getClassNameFromType(elementType)
         val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
-        val arrayData =
-          Invoke(
-            MapObjects(
-              p => constructorFor(elementType, Some(p), newTypePath),
-              getPath,
-              schemaFor(elementType).dataType),
-            "array",
-            ObjectType(classOf[Array[Any]]))
+
+        val mapFunction: Expression => Expression = p => {
+          val converter = constructorFor(elementType, Some(p), newTypePath)
+          if (nullable) {
+            converter
+          } else {
+            AssertNotNull(converter, newTypePath)
+          }
+        }
+
+        val array = Invoke(
+          MapObjects(mapFunction, getPath, dataType),
+          "array",
+          ObjectType(classOf[Array[Any]]))
 
         StaticInvoke(
           scala.collection.mutable.WrappedArray.getClass,
           ObjectType(classOf[Seq[_]]),
           "make",
-          arrayData :: Nil)
+          array :: Nil)
 
       case t if t <:< localTypeOf[Map[_, _]] =>
         // TODO: add walked type path for map
@@ -343,7 +352,7 @@ object ScalaReflection extends ScalaReflection {
               newTypePath)
 
             if (!nullable) {
-              AssertNotNull(constructor, t.toString, fieldName, fieldType.toString)
+              AssertNotNull(constructor, newTypePath)
             } else {
               constructor
             }
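
The mapFunction above is the core of the fix: when the element type cannot hold null, each per-element converter is wrapped in AssertNotNull instead of the schema being rejected up front. As a plain-Scala analogy (hypothetical helper, not Catalyst code), the pattern amounts to:

  // Hypothetical stand-in for the generated per-element conversion:
  // decode each element, failing fast on null with the walked type path.
  def decodeIntSeq(raw: Seq[Any], walkedTypePath: Seq[String]): Seq[Int] =
    raw.map {
      case null =>
        throw new RuntimeException(
          "Null value appeared in non-nullable field:" +
            walkedTypePath.mkString("\n", "\n", "\n"))
      case i: Int => i
    }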

http://git-wip-us.apache.org/repos/asf/spark/blob/8e4d15f7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cb228cf..4d53b23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1426,7 +1426,7 @@ object ResolveUpCast extends Rule[LogicalPlan] {
           fail(child, DateType, walkedTypePath)
         case (StringType, to: NumericType) =>
           fail(child, to, walkedTypePath)
-        case _ => Cast(child, dataType)
+        case _ => Cast(child, dataType.asNullable)
       }
     }
   }
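
The switch to dataType.asNullable complements the runtime check: since AssertNotNull now enforces non-null values, the upcast target can be relaxed so analysis no longer trips on nullability alone. A sketch of the effect, assuming the usual behavior of the private[spark] DataType.asNullable helper (it recursively marks the type, array elements, and struct fields as nullable):

  // Illustrative only; asNullable is package-private to Spark.
  ArrayType(IntegerType, containsNull = false).asNullable
  // => ArrayType(IntegerType, containsNull = true)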

http://git-wip-us.apache.org/repos/asf/spark/blob/8e4d15f7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 79fe003..fef6825 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -365,7 +365,7 @@ object MapObjects {
  *                       to handle collection elements.
 * @param inputData An expression that when evaluated returns a collection object.
  */
-case class MapObjects(
+case class MapObjects private(
     loopVar: LambdaVariable,
     lambdaFunction: Expression,
     inputData: Expression) extends Expression {
@@ -637,8 +637,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
  * `Int` field named `i`.  Expression `s.i` is nullable because `s` can be null.  However, for all
  * non-null `s`, `s.i` can't be null.
  */
-case class AssertNotNull(
-    child: Expression, parentType: String, fieldName: String, fieldType: String)
+case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
   extends UnaryExpression {
 
   override def dataType: DataType = child.dataType
@@ -651,6 +650,14 @@ case class AssertNotNull(
   override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
     val childGen = child.gen(ctx)
 
+    val errMsg = "Null value appeared in non-nullable field:" +
+      walkedTypePath.mkString("\n", "\n", "\n") +
+      "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
+      "please try to use scala.Option[_] or other nullable types " +
+      "(e.g. java.lang.Integer instead of int/scala.Int)."
+    val idx = ctx.references.length
+    ctx.references += errMsg
+
     ev.isNull = "false"
     ev.value = childGen.value
 
@@ -658,12 +665,7 @@ case class AssertNotNull(
       ${childGen.code}
 
       if (${childGen.isNull}) {
-        throw new RuntimeException(
-          "Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " +
-          "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
-          "please try to use scala.Option[_] or other nullable types " +
-          "(e.g. java.lang.Integer instead of int/scala.Int)."
-        );
+        throw new RuntimeException((String) references[$idx]);
       }
      """
   }
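
Note that the error message is now assembled once at codegen time and stashed in ctx.references, with the generated Java merely throwing the referenced string; presumably this avoids escaping the multi-line walkedTypePath (with its embedded quotes) inside a Java string literal. A sketch of the assembled message, with an illustrative walkedTypePath (sample contents, not from a real run):

  val walkedTypePath = Seq(
    """- array element class: "scala.Int"""",
    """- root class: "scala.collection.Seq"""")
  val errMsg = "Null value appeared in non-nullable field:" +
    walkedTypePath.mkString("\n", "\n", "\n") +
    "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
    "please try to use scala.Option[_] or other nullable types " +
    "(e.g. java.lang.Integer instead of int/scala.Int)."
  println(errMsg)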

http://git-wip-us.apache.org/repos/asf/spark/blob/8e4d15f7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 92a68a4..8b02b63 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -21,9 +21,11 @@ import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 case class StringLongClass(a: String, b: Long)
 
@@ -32,94 +34,49 @@ case class StringIntClass(a: String, b: Int)
 case class ComplexClass(a: Long, b: StringLongClass)
 
 class EncoderResolutionSuite extends PlanTest {
+  private val str = UTF8String.fromString("hello")
+
   test("real type doesn't match encoder schema but they are compatible: product") {
     val encoder = ExpressionEncoder[StringLongClass]
-    val cls = classOf[StringLongClass]
-
 
-    {
-      val attrs = Seq('a.string, 'b.int)
-      val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
-      val expected: Expression = NewInstance(
-        cls,
-        Seq(
-          toExternalString('a.string),
-          AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
-        ),
-        ObjectType(cls),
-        propagateNull = false)
-      compareExpressions(fromRowExpr, expected)
-    }
+    // int type can be up cast to long type
+    val attrs1 = Seq('a.string, 'b.int)
+    encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1))
 
-    {
-      val attrs = Seq('a.int, 'b.long)
-      val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
-      val expected = NewInstance(
-        cls,
-        Seq(
-          toExternalString('a.int.cast(StringType)),
-          AssertNotNull('b.long, cls.getName, "b", "Long")
-        ),
-        ObjectType(cls),
-        propagateNull = false)
-      compareExpressions(fromRowExpr, expected)
-    }
+    // int type can be up cast to string type
+    val attrs2 = Seq('a.int, 'b.long)
+    encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L))
   }
 
   test("real type doesn't match encoder schema but they are compatible: nested product") {
     val encoder = ExpressionEncoder[ComplexClass]
-    val innerCls = classOf[StringLongClass]
-    val cls = classOf[ComplexClass]
-
     val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
-    val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
-    val expected: Expression = NewInstance(
-      cls,
-      Seq(
-        AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"),
-        If(
-          'b.struct('a.int, 'b.long).isNull,
-          Literal.create(null, ObjectType(innerCls)),
-          NewInstance(
-            innerCls,
-            Seq(
-              toExternalString(
-                GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)),
-              AssertNotNull(
-                GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
-                innerCls.getName, "b", "Long")),
-            ObjectType(innerCls),
-            propagateNull = false)
-        )),
-      ObjectType(cls),
-      propagateNull = false)
-    compareExpressions(fromRowExpr, expected)
+    encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
   }
 
   test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
     val encoder = ExpressionEncoder.tuple(
       ExpressionEncoder[StringLongClass],
       ExpressionEncoder[Long])
-    val cls = classOf[StringLongClass]
-
     val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
-    val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
-    val expected: Expression = NewInstance(
-      classOf[Tuple2[_, _]],
-      Seq(
-        NewInstance(
-          cls,
-          Seq(
-            toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))),
-            AssertNotNull(
-              GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
-              cls.getName, "b", "Long")),
-          ObjectType(cls),
-          propagateNull = false),
-        'b.int.cast(LongType)),
-      ObjectType(classOf[Tuple2[_, _]]),
-      propagateNull = false)
-    compareExpressions(fromRowExpr, expected)
+    encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
+  }
+
+  test("nullability of array type element should not fail analysis") {
+    val encoder = ExpressionEncoder[Seq[Int]]
+    val attrs = 'a.array(IntegerType) :: Nil
+
+    // It should pass analysis
+    val bound = encoder.resolve(attrs, null).bind(attrs)
+
+    // If no null values appear, it should work fine
+    bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
+
+    // If there is a null value, it should throw a runtime exception
+    val e = intercept[RuntimeException] {
+      bound.fromRow(InternalRow(new GenericArrayData(Array(1, null))))
+    }
+    assert(e.getMessage.contains("Null value appeared in non-nullable field"))
   }
 
   test("the real number of fields doesn't match encoder schema: tuple encoder") {
@@ -166,10 +123,6 @@ class EncoderResolutionSuite extends PlanTest {
     }
   }
 
-  private def toExternalString(e: Expression): Expression = {
-    Invoke(e, "toString", ObjectType(classOf[String]), Nil)
-  }
-
   test("throw exception if real type is not compatible with encoder schema") {
     val msg1 = intercept[AnalysisException] {
       ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null)

http://git-wip-us.apache.org/repos/asf/spark/blob/8e4d15f7/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a6fb62c..1181244 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -850,9 +850,7 @@ public class JavaDatasetSuite implements Serializable {
     }
 
     nullabilityCheck.expect(RuntimeException.class);
-    nullabilityCheck.expectMessage(
-      "Null value appeared in non-nullable field " +
-        "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int.");
+    nullabilityCheck.expectMessage("Null value appeared in non-nullable field");
 
     {
       Row row = new GenericRow(new Object[] {

http://git-wip-us.apache.org/repos/asf/spark/blob/8e4d15f7/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 374f432..f9ba607 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -553,9 +553,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       buildDataset(Row(Row("hello", null))).collect()
     }.getMessage
 
-    assert(message.contains(
-      "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
-    ))
+    assert(message.contains("Null value appeared in non-nullable field"))
   }
 
   test("SPARK-12478: top level null field") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org