Posted to commits@spark.apache.org by ma...@apache.org on 2016/02/04 01:13:27 UTC

spark git commit: [SPARK-13101][SQL][BRANCH-1.6] nullability of array type element should not fail analysis of encoder

Repository: spark
Updated Branches:
  refs/heads/branch-1.6 5fe8796c2 -> cdfb2a141


[SPARK-13101][SQL][BRANCH-1.6] nullability of array type element should not fail analysis of encoder

Nullability should only be considered an optimization rather than part of the type system, so instead of failing analysis on mismatched nullability, we should pass analysis and add a runtime null check.

backport https://github.com/apache/spark/pull/11035 to 1.6
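
The core of the change: where the old deserializer construction failed analysis on a nullability mismatch, the new one wraps the per-field (or per-element) converter in AssertNotNull, deferring the check to evaluation time. A minimal plain-Scala sketch of the idea (`assertNotNull` here is a made-up stand-in for illustration; the real patch builds Catalyst expression trees, not closures):

  // Hypothetical stand-in for Catalyst's AssertNotNull expression: report the
  // walked type path when a null shows up where the target type forbids it.
  def assertNotNull[T](value: T, walkedTypePath: Seq[String]): T = {
    if (value == null) {
      throw new RuntimeException(
        "Null value appeared in non-nullable field:\n" + walkedTypePath.mkString("\n"))
    }
    value
  }

  assertNotNull[java.lang.Integer](
    null,
    Seq("- array element class: \"scala.Int\"", "- root class: \"scala.collection.Seq\""))
  // throws RuntimeException with the type path in the message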

Author: Wenchen Fan <we...@databricks.com>

Closes #11042 from cloud-fan/branch-1.6.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cdfb2a14
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cdfb2a14
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cdfb2a14

Branch: refs/heads/branch-1.6
Commit: cdfb2a1410aa799596c8b751187dbac28b2cc678
Parents: 5fe8796
Author: Wenchen Fan <we...@databricks.com>
Authored: Wed Feb 3 16:13:23 2016 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Wed Feb 3 16:13:23 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/JavaTypeInference.scala  |   2 +-
 .../spark/sql/catalyst/ScalaReflection.scala    |  29 +++--
 .../spark/sql/catalyst/analysis/Analyzer.scala  |   2 +-
 .../sql/catalyst/expressions/objects.scala      |  16 ++-
 .../encoders/EncoderResolutionSuite.scala       | 107 ++++++-------------
 .../org/apache/spark/sql/JavaDatasetSuite.java  |   4 +-
 .../org/apache/spark/sql/DatasetSuite.scala     |  13 +--
 7 files changed, 69 insertions(+), 104 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cdfb2a14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 4c00803..a7d6a3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -293,7 +293,7 @@ object JavaTypeInference {
           val setter = if (nullable) {
             constructor
           } else {
-            AssertNotNull(constructor, other.getName, fieldName, fieldType.toString)
+            AssertNotNull(constructor, Seq("currently no type path record in java"))
           }
           p.getWriteMethod.getName -> setter
         }.toMap

http://git-wip-us.apache.org/repos/asf/spark/blob/cdfb2a14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index b0efdf3..8722191 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -249,6 +249,8 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
+
+        // TODO: add runtime null check for primitive array
         val primitiveMethod = elementType match {
           case t if t <:< definitions.IntTpe => Some("toIntArray")
           case t if t <:< definitions.LongTpe => Some("toLongArray")
@@ -276,22 +278,29 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[Seq[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
+        val Schema(dataType, nullable) = schemaFor(elementType)
         val className = getClassNameFromType(elementType)
         val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
-        val arrayData =
-          Invoke(
-            MapObjects(
-              p => constructorFor(elementType, Some(p), newTypePath),
-              getPath,
-              schemaFor(elementType).dataType),
-            "array",
-            ObjectType(classOf[Array[Any]]))
+
+        val mapFunction: Expression => Expression = p => {
+          val converter = constructorFor(elementType, Some(p), newTypePath)
+          if (nullable) {
+            converter
+          } else {
+            AssertNotNull(converter, newTypePath)
+          }
+        }
+
+        val array = Invoke(
+          MapObjects(mapFunction, getPath, dataType),
+          "array",
+          ObjectType(classOf[Array[Any]]))
 
         StaticInvoke(
           scala.collection.mutable.WrappedArray.getClass,
           ObjectType(classOf[Seq[_]]),
           "make",
-          arrayData :: Nil)
+          array :: Nil)
 
       case t if t <:< localTypeOf[Map[_, _]] =>
         // TODO: add walked type path for map
@@ -343,7 +352,7 @@ object ScalaReflection extends ScalaReflection {
               newTypePath)
 
             if (!nullable) {
-              AssertNotNull(constructor, t.toString, fieldName, fieldType.toString)
+              AssertNotNull(constructor, newTypePath)
             } else {
               constructor
             }
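
The mapFunction above pairs MapObjects with AssertNotNull so the element converter itself carries the null check. A plain-Scala analogue of what the evaluated code ends up doing for a Seq with non-nullable elements (illustration only, with made-up names; the real code emits Java source via codegen):

  // For each element: convert, then assert non-null when the target element
  // type (e.g. scala.Int) cannot represent null.
  def deserializeSeq(
      input: Seq[Any],
      elementNullable: Boolean,
      newTypePath: Seq[String]): Seq[Any] = {
    input.map { element =>
      if (!elementNullable && element == null) {
        throw new RuntimeException(
          "Null value appeared in non-nullable field:\n" + newTypePath.mkString("\n"))
      }
      element  // the real code would also run the element converter here
    }
  }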

http://git-wip-us.apache.org/repos/asf/spark/blob/cdfb2a14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8c3cdfb..bc62c7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1223,7 +1223,7 @@ object ResolveUpCast extends Rule[LogicalPlan] {
           fail(child, DateType, walkedTypePath)
         case (StringType, to: NumericType) =>
           fail(child, to, walkedTypePath)
-        case _ => Cast(child, dataType)
+        case _ => Cast(child, dataType.asNullable)
       }
     }
   }
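
`asNullable` relaxes every nullability flag in the target type before the cast, which is what lets a column typed array<int> with containsNull = true be cast toward an encoder schema whose element is non-nullable. The method is Spark-internal (private[spark]); a hedged stand-in that mirrors its effect for the types in play (`relaxNullability` is a made-up name):

  import org.apache.spark.sql.types._

  // Mirrors what DataType.asNullable does (illustration; not the internal method):
  def relaxNullability(dt: DataType): DataType = dt match {
    case ArrayType(et, _) => ArrayType(relaxNullability(et), containsNull = true)
    case MapType(kt, vt, _) =>
      MapType(relaxNullability(kt), relaxNullability(vt), valueContainsNull = true)
    case StructType(fields) =>
      StructType(fields.map(f =>
        f.copy(dataType = relaxNullability(f.dataType), nullable = true)))
    case other => other
  }

  // relaxNullability(ArrayType(IntegerType, containsNull = false))
  //   == ArrayType(IntegerType, containsNull = true)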

http://git-wip-us.apache.org/repos/asf/spark/blob/cdfb2a14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index c0c3e6e..f4b0cdc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -361,7 +361,7 @@ object MapObjects {
  *                       to handle collection elements.
 * @param inputData An expression that when evaluated returns a collection object.
  */
-case class MapObjects(
+case class MapObjects private(
     loopVar: LambdaVariable,
     lambdaFunction: Expression,
     inputData: Expression) extends Expression {
@@ -633,8 +633,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
  * `Int` field named `i`.  Expression `s.i` is nullable because `s` can be null.  However, for all
  * non-null `s`, `s.i` can't be null.
  */
-case class AssertNotNull(
-    child: Expression, parentType: String, fieldName: String, fieldType: String)
+case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
   extends UnaryExpression {
 
   override def dataType: DataType = child.dataType
@@ -647,6 +646,14 @@ case class AssertNotNull(
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val childGen = child.gen(ctx)
 
+    // This is going to be a string literal in the generated Java code, so we should escape `"`
+    // as `\"`, wrap every line with `"` on the left and `\n"` on the right, and finally
+    // concatenate them with ` + `.
+    val typePathString = walkedTypePath
+      .map(s => s.replaceAll("\"", "\\\\\""))
+      .map(s => '"' + s + "\\n\"")
+      .mkString(" + ")
+
     ev.isNull = "false"
     ev.value = childGen.value
 
@@ -655,7 +662,8 @@ case class AssertNotNull(
 
       if (${childGen.isNull}) {
         throw new RuntimeException(
-          "Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " +
+          "Null value appeared in non-nullable field:\\n" +
+          $typePathString +
           "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
           "please try to use scala.Option[_] or other nullable types " +
           "(e.g. java.lang.Integer instead of int/scala.Int)."

http://git-wip-us.apache.org/repos/asf/spark/blob/cdfb2a14/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index bc36a55..1d7a708 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -21,9 +21,11 @@ import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 case class StringLongClass(a: String, b: Long)
 
@@ -32,94 +34,49 @@ case class StringIntClass(a: String, b: Int)
 case class ComplexClass(a: Long, b: StringLongClass)
 
 class EncoderResolutionSuite extends PlanTest {
+  private val str = UTF8String.fromString("hello")
+
   test("real type doesn't match encoder schema but they are compatible: product") {
     val encoder = ExpressionEncoder[StringLongClass]
-    val cls = classOf[StringLongClass]
-
 
-    {
-      val attrs = Seq('a.string, 'b.int)
-      val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
-      val expected: Expression = NewInstance(
-        cls,
-        Seq(
-          toExternalString('a.string),
-          AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
-        ),
-        ObjectType(cls),
-        propagateNull = false)
-      compareExpressions(fromRowExpr, expected)
-    }
+    // int type can be upcast to long type
+    val attrs1 = Seq('a.string, 'b.int)
+    encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1))
 
-    {
-      val attrs = Seq('a.int, 'b.long)
-      val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
-      val expected = NewInstance(
-        cls,
-        Seq(
-          toExternalString('a.int.cast(StringType)),
-          AssertNotNull('b.long, cls.getName, "b", "Long")
-        ),
-        ObjectType(cls),
-        propagateNull = false)
-      compareExpressions(fromRowExpr, expected)
-    }
+    // int type can be upcast to string type
+    val attrs2 = Seq('a.int, 'b.long)
+    encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L))
   }
 
   test("real type doesn't match encoder schema but they are compatible: nested product") {
     val encoder = ExpressionEncoder[ComplexClass]
-    val innerCls = classOf[StringLongClass]
-    val cls = classOf[ComplexClass]
-
     val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
-    val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
-    val expected: Expression = NewInstance(
-      cls,
-      Seq(
-        AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"),
-        If(
-          'b.struct('a.int, 'b.long).isNull,
-          Literal.create(null, ObjectType(innerCls)),
-          NewInstance(
-            innerCls,
-            Seq(
-              toExternalString(
-                GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)),
-              AssertNotNull(
-                GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
-                innerCls.getName, "b", "Long")),
-            ObjectType(innerCls),
-            propagateNull = false)
-        )),
-      ObjectType(cls),
-      propagateNull = false)
-    compareExpressions(fromRowExpr, expected)
+    encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
   }
 
   test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
     val encoder = ExpressionEncoder.tuple(
       ExpressionEncoder[StringLongClass],
       ExpressionEncoder[Long])
-    val cls = classOf[StringLongClass]
-
     val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
-    val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
-    val expected: Expression = NewInstance(
-      classOf[Tuple2[_, _]],
-      Seq(
-        NewInstance(
-          cls,
-          Seq(
-            toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))),
-            AssertNotNull(
-              GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
-              cls.getName, "b", "Long")),
-          ObjectType(cls),
-          propagateNull = false),
-        'b.int.cast(LongType)),
-      ObjectType(classOf[Tuple2[_, _]]),
-      propagateNull = false)
-    compareExpressions(fromRowExpr, expected)
+    encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
+  }
+
+  test("nullability of array type element should not fail analysis") {
+    val encoder = ExpressionEncoder[Seq[Int]]
+    val attrs = 'a.array(IntegerType) :: Nil
+
+    // It should pass analysis
+    val bound = encoder.resolve(attrs, null).bind(attrs)
+
+    // If no null values appear, it should work fine
+    bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
+
+    // If there is a null value, it should throw a runtime exception
+    val e = intercept[RuntimeException] {
+      bound.fromRow(InternalRow(new GenericArrayData(Array(1, null))))
+    }
+    assert(e.getMessage.contains("Null value appeared in non-nullable field"))
   }
 
   test("the real number of fields doesn't match encoder schema: tuple encoder") {
@@ -166,10 +123,6 @@ class EncoderResolutionSuite extends PlanTest {
     }
   }
 
-  private def toExternalString(e: Expression): Expression = {
-    Invoke(e, "toString", ObjectType(classOf[String]), Nil)
-  }
-
   test("throw exception if real type is not compatible with encoder schema") {
     val msg1 = intercept[AnalysisException] {
       ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null)
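
The rewritten tests all follow one pattern: resolve adapts the encoder to the input attributes (analysis time, where the old code threw), bind fixes attribute ordinals, and fromRow runs the deserializer (evaluation time, where the new AssertNotNull fires). A sketch of that round trip, assuming this suite's imports (all names are taken from the diff above):

  val enc = ExpressionEncoder[StringLongClass]
  val attrs = Seq('a.string, 'b.int)
  // resolve: analysis-time checks; bind: ordinal resolution; fromRow: evaluation.
  val obj = enc.resolve(attrs, null).bind(attrs)
    .fromRow(InternalRow(UTF8String.fromString("hello"), 1))
  // obj == StringLongClass("hello", 1L): the int column was upcast to Long.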

http://git-wip-us.apache.org/repos/asf/spark/blob/cdfb2a14/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 9f8db39..9fe0c39 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -848,9 +848,7 @@ public class JavaDatasetSuite implements Serializable {
     }
 
     nullabilityCheck.expect(RuntimeException.class);
-    nullabilityCheck.expectMessage(
-      "Null value appeared in non-nullable field " +
-        "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int.");
+    nullabilityCheck.expectMessage("Null value appeared in non-nullable field");
 
     {
       Row row = new GenericRow(new Object[] {

http://git-wip-us.apache.org/repos/asf/spark/blob/cdfb2a14/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index c19b5a4..ff9cd27 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -44,13 +44,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       1, 1, 1)
   }
 
-
   test("SPARK-12404: Datatype Helper Serializablity") {
     val ds = sparkContext.parallelize((
-          new Timestamp(0),
-          new Date(0),
-          java.math.BigDecimal.valueOf(1),
-          scala.math.BigDecimal(1)) :: Nil).toDS()
+      new Timestamp(0),
+      new Date(0),
+      java.math.BigDecimal.valueOf(1),
+      scala.math.BigDecimal(1)) :: Nil).toDS()
 
     ds.collect()
   }
@@ -542,9 +541,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       buildDataset(Row(Row("hello", null))).collect()
     }.getMessage
 
-    assert(message.contains(
-      "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
-    ))
+    assert(message.contains("Null value appeared in non-nullable field"))
   }
 
   test("SPARK-12478: top level null field") {

