Posted to commits@spark.apache.org by we...@apache.org on 2017/07/05 12:32:57 UTC

spark git commit: [SPARK-16167][SQL] RowEncoder should preserve array/map type nullability.

Repository: spark
Updated Branches:
  refs/heads/master 4852b7d44 -> 873f3ad2b


[SPARK-16167][SQL] RowEncoder should preserve array/map type nullability.

## What changes were proposed in this pull request?

Currently `RowEncoder` doesn't preserve the nullability of `ArrayType` or `MapType`.
It always returns `containsNull = true` for `ArrayType` and `valueContainsNull = true` for `MapType`, and the nullability of the field itself is always `true` as well.
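For illustration only (this snippet is not part of the diff below), a minimal sketch of the behavior in question, using the same `RowEncoder`/`resolveAndBind` API that the new tests exercise:

```scala
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

// Schema that explicitly forbids null elements and null fields.
val schema = new StructType()
  .add("array", ArrayType(IntegerType, containsNull = false), nullable = false)

val encoder = RowEncoder(schema).resolveAndBind()

// Before this change the serializer reported ArrayType(IntegerType, containsNull = true)
// and a nullable output regardless of the schema; after it, both match the schema.
println(encoder.serializer.head.dataType)
println(encoder.serializer.head.nullable)
```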

This PR fixes the nullability handling for both.

## How was this patch tested?

Add tests to check if `RowEncoder` preserves array/map nullability.
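For reference, a standalone sketch of the shape of the added checks for the map case (the actual suite in the diff below iterates over more key/value type and nullability combinations):

```scala
// Mirrors the new RowEncoderSuite assertions: the serializer must report the
// declared valueContainsNull and field nullability, not hard-coded `true`.
val mapSchema = new StructType()
  .add("map", MapType(IntegerType, StringType, valueContainsNull = false), nullable = true)
val mapEncoder = RowEncoder(mapSchema).resolveAndBind()

assert(mapEncoder.serializer.length == 1)
assert(mapEncoder.serializer.head.dataType ==
  MapType(IntegerType, StringType, valueContainsNull = false))
assert(mapEncoder.serializer.head.nullable)
```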

Author: Takuya UESHIN <ue...@happy-camper.st>
Author: Takuya UESHIN <ue...@databricks.com>

Closes #13873 from ueshin/issues/SPARK-16167.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/873f3ad2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/873f3ad2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/873f3ad2

Branch: refs/heads/master
Commit: 873f3ad2b89c955f42fced49dc129e8efa77d044
Parents: 4852b7d
Author: Takuya UESHIN <ue...@happy-camper.st>
Authored: Wed Jul 5 20:32:47 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Wed Jul 5 20:32:47 2017 +0800

----------------------------------------------------------------------
 .../sql/catalyst/encoders/RowEncoder.scala      | 25 ++++++++++++---
 .../sql/catalyst/encoders/RowEncoderSuite.scala | 33 ++++++++++++++++++++
 2 files changed, 54 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/873f3ad2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index cc32fac..43c35bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -123,7 +123,7 @@ object RowEncoder {
         inputObject :: Nil,
         returnNullable = false)
 
-    case t @ ArrayType(et, cn) =>
+    case t @ ArrayType(et, containsNull) =>
       et match {
         case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
           StaticInvoke(
@@ -132,8 +132,16 @@ object RowEncoder {
             "toArrayData",
             inputObject :: Nil,
             returnNullable = false)
+
         case _ => MapObjects(
-          element => serializerFor(ValidateExternalType(element, et), et),
+          element => {
+            val value = serializerFor(ValidateExternalType(element, et), et)
+            if (!containsNull) {
+              AssertNotNull(value, Seq.empty)
+            } else {
+              value
+            }
+          },
           inputObject,
           ObjectType(classOf[Object]))
       }
@@ -155,10 +163,19 @@ object RowEncoder {
           ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
       val convertedValues = serializerFor(values, ArrayType(vt, valueNullable))
 
-      NewInstance(
+      val nonNullOutput = NewInstance(
         classOf[ArrayBasedMapData],
         convertedKeys :: convertedValues :: Nil,
-        dataType = t)
+        dataType = t,
+        propagateNull = false)
+
+      if (inputObject.nullable) {
+        If(IsNull(inputObject),
+          Literal.create(null, inputType),
+          nonNullOutput)
+      } else {
+        nonNullOutput
+      }
 
     case StructType(fields) =>
       val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/873f3ad2/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 1a5569a..6ed175f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -273,6 +273,39 @@ class RowEncoderSuite extends SparkFunSuite {
     assert(e4.getMessage.contains("java.lang.String is not a valid external type"))
   }
 
+  for {
+    elementType <- Seq(IntegerType, StringType)
+    containsNull <- Seq(true, false)
+    nullable <- Seq(true, false)
+  } {
+    test("RowEncoder should preserve array nullability: " +
+      s"ArrayType($elementType, containsNull = $containsNull), nullable = $nullable") {
+      val schema = new StructType().add("array", ArrayType(elementType, containsNull), nullable)
+      val encoder = RowEncoder(schema).resolveAndBind()
+      assert(encoder.serializer.length == 1)
+      assert(encoder.serializer.head.dataType == ArrayType(elementType, containsNull))
+      assert(encoder.serializer.head.nullable == nullable)
+    }
+  }
+
+  for {
+    keyType <- Seq(IntegerType, StringType)
+    valueType <- Seq(IntegerType, StringType)
+    valueContainsNull <- Seq(true, false)
+    nullable <- Seq(true, false)
+  } {
+    test("RowEncoder should preserve map nullability: " +
+      s"MapType($keyType, $valueType, valueContainsNull = $valueContainsNull), " +
+      s"nullable = $nullable") {
+      val schema = new StructType().add(
+        "map", MapType(keyType, valueType, valueContainsNull), nullable)
+      val encoder = RowEncoder(schema).resolveAndBind()
+      assert(encoder.serializer.length == 1)
+      assert(encoder.serializer.head.dataType == MapType(keyType, valueType, valueContainsNull))
+      assert(encoder.serializer.head.nullable == nullable)
+    }
+  }
+
   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {
       val encoder = RowEncoder(schema).resolveAndBind()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org