Posted to commits@spark.apache.org by we...@apache.org on 2017/04/10 02:47:19 UTC

spark git commit: [SPARK-20253][SQL] Remove unnecessary nullchecks of a return value from Spark runtime routines in generated Java code

Repository: spark
Updated Branches:
  refs/heads/master 261eaf514 -> 7a63f5e82


[SPARK-20253][SQL] Remove unnecessary nullchecks of a return value from Spark runtime routines in generated Java code

## What changes were proposed in this pull request?

This PR eliminates unnecessary null checks on values returned from known Spark runtime routines. For a given Spark runtime routine we know whether it can return ``null`` (e.g. ``ArrayData.toDoubleArray()`` never returns ``null``), so we can drop the null check on its return value in the generated code.

When we run the following example program, we currently get the Java code shown under "Without this PR". Since ``ArrayData.toDoubleArray()`` never returns ``null``, the null checks at lines 90-92 and 97 in that code can be eliminated.

```scala
val ds = sparkContext.parallelize(Seq(Array(1.1, 2.2)), 1).toDS.cache
ds.count
ds.map(e => e).show
```
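
To see the effect yourself, one way to dump the whole-stage generated Java is the ``debugCodegen()`` helper from ``org.apache.spark.sql.execution.debug``. This is a minimal sketch, assuming a running ``SparkSession`` named ``spark``; it is not part of this PR.

```scala
import org.apache.spark.sql.execution.debug._  // provides debugCodegen() on a Dataset
import spark.implicits._

val ds = spark.sparkContext.parallelize(Seq(Array(1.1, 2.2)), 1).toDS.cache
ds.count
// Prints the generated Java for each whole-stage codegen subtree; look for the
// null check around toDoubleArray() with and without this PR.
ds.map(e => e).debugCodegen()
```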

Without this PR
```java
/* 050 */   protected void processNext() throws java.io.IOException {
/* 051 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 052 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 053 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 054 */       ArrayData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getArray(0));
/* 055 */
/* 056 */       ArrayData deserializetoobject_value1 = null;
/* 057 */
/* 058 */       if (!inputadapter_isNull) {
/* 059 */         int deserializetoobject_dataLength = inputadapter_value.numElements();
/* 060 */
/* 061 */         Double[] deserializetoobject_convertedArray = null;
/* 062 */         deserializetoobject_convertedArray = new Double[deserializetoobject_dataLength];
/* 063 */
/* 064 */         int deserializetoobject_loopIndex = 0;
/* 065 */         while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 066 */           MapObjects_loopValue2 = (double) (inputadapter_value.getDouble(deserializetoobject_loopIndex));
/* 067 */           MapObjects_loopIsNull2 = inputadapter_value.isNullAt(deserializetoobject_loopIndex);
/* 068 */
/* 069 */           if (MapObjects_loopIsNull2) {
/* 070 */             throw new RuntimeException(((java.lang.String) references[0]));
/* 071 */           }
/* 072 */           if (false) {
/* 073 */             deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null;
/* 074 */           } else {
/* 075 */             deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue2;
/* 076 */           }
/* 077 */
/* 078 */           deserializetoobject_loopIndex += 1;
/* 079 */         }
/* 080 */
/* 081 */         deserializetoobject_value1 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray); /*###*/
/* 082 */       }
/* 083 */       boolean deserializetoobject_isNull = true;
/* 084 */       double[] deserializetoobject_value = null;
/* 085 */       if (!inputadapter_isNull) {
/* 086 */         deserializetoobject_isNull = false;
/* 087 */         if (!deserializetoobject_isNull) {
/* 088 */           Object deserializetoobject_funcResult = null;
/* 089 */           deserializetoobject_funcResult = deserializetoobject_value1.toDoubleArray();
/* 090 */           if (deserializetoobject_funcResult == null) {
/* 091 */             deserializetoobject_isNull = true;
/* 092 */           } else {
/* 093 */             deserializetoobject_value = (double[]) deserializetoobject_funcResult;
/* 094 */           }
/* 095 */
/* 096 */         }
/* 097 */         deserializetoobject_isNull = deserializetoobject_value == null;
/* 098 */       }
/* 099 */
/* 100 */       boolean mapelements_isNull = true;
/* 101 */       double[] mapelements_value = null;
/* 102 */       if (!false) {
/* 103 */         mapelements_resultIsNull = false;
/* 104 */
/* 105 */         if (!mapelements_resultIsNull) {
/* 106 */           mapelements_resultIsNull = deserializetoobject_isNull;
/* 107 */           mapelements_argValue = deserializetoobject_value;
/* 108 */         }
/* 109 */
/* 110 */         mapelements_isNull = mapelements_resultIsNull;
/* 111 */         if (!mapelements_isNull) {
/* 112 */           Object mapelements_funcResult = null;
/* 113 */           mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue);
/* 114 */           if (mapelements_funcResult == null) {
/* 115 */             mapelements_isNull = true;
/* 116 */           } else {
/* 117 */             mapelements_value = (double[]) mapelements_funcResult;
/* 118 */           }
/* 119 */
/* 120 */         }
/* 121 */         mapelements_isNull = mapelements_value == null;
/* 122 */       }
/* 123 */
/* 124 */       serializefromobject_resultIsNull = false;
/* 125 */
/* 126 */       if (!serializefromobject_resultIsNull) {
/* 127 */         serializefromobject_resultIsNull = mapelements_isNull;
/* 128 */         serializefromobject_argValue = mapelements_value;
/* 129 */       }
/* 130 */
/* 131 */       boolean serializefromobject_isNull = serializefromobject_resultIsNull;
/* 132 */       final ArrayData serializefromobject_value = serializefromobject_resultIsNull ? null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(serializefromobject_argValue);
/* 133 */       serializefromobject_isNull = serializefromobject_value == null;
/* 134 */       serializefromobject_holder.reset();
/* 135 */
/* 136 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 137 */
/* 138 */       if (serializefromobject_isNull) {
/* 139 */         serializefromobject_rowWriter.setNullAt(0);
/* 140 */       } else {
/* 141 */         // Remember the current cursor so that we can calculate how many bytes are
/* 142 */         // written later.
/* 143 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 144 */
/* 145 */         if (serializefromobject_value instanceof UnsafeArrayData) {
/* 146 */           final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 147 */           // grow the global buffer before writing data.
/* 148 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 149 */           ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 150 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 151 */
/* 152 */         } else {
/* 153 */           final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 154 */           serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 8);
/* 155 */
/* 156 */           for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) {
/* 157 */             if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 158 */               serializefromobject_arrayWriter.setNullDouble(serializefromobject_index);
/* 159 */             } else {
/* 160 */               final double serializefromobject_element = serializefromobject_value.getDouble(serializefromobject_index);
/* 161 */               serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 162 */             }
/* 163 */           }
/* 164 */         }
/* 165 */
/* 166 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 167 */       }
/* 168 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 169 */       append(serializefromobject_result);
/* 170 */       if (shouldStop()) return;
/* 171 */     }
/* 172 */   }
```

With this PR (most of lines 90-97 in the code above have been removed)
```java
/* 050 */   protected void processNext() throws java.io.IOException {
/* 051 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 052 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 053 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 054 */       ArrayData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getArray(0));
/* 055 */
/* 056 */       ArrayData deserializetoobject_value1 = null;
/* 057 */
/* 058 */       if (!inputadapter_isNull) {
/* 059 */         int deserializetoobject_dataLength = inputadapter_value.numElements();
/* 060 */
/* 061 */         Double[] deserializetoobject_convertedArray = null;
/* 062 */         deserializetoobject_convertedArray = new Double[deserializetoobject_dataLength];
/* 063 */
/* 064 */         int deserializetoobject_loopIndex = 0;
/* 065 */         while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 066 */           MapObjects_loopValue2 = (double) (inputadapter_value.getDouble(deserializetoobject_loopIndex));
/* 067 */           MapObjects_loopIsNull2 = inputadapter_value.isNullAt(deserializetoobject_loopIndex);
/* 068 */
/* 069 */           if (MapObjects_loopIsNull2) {
/* 070 */             throw new RuntimeException(((java.lang.String) references[0]));
/* 071 */           }
/* 072 */           if (false) {
/* 073 */             deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null;
/* 074 */           } else {
/* 075 */             deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue2;
/* 076 */           }
/* 077 */
/* 078 */           deserializetoobject_loopIndex += 1;
/* 079 */         }
/* 080 */
/* 081 */         deserializetoobject_value1 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray); /*###*/
/* 082 */       }
/* 083 */       boolean deserializetoobject_isNull = true;
/* 084 */       double[] deserializetoobject_value = null;
/* 085 */       if (!inputadapter_isNull) {
/* 086 */         deserializetoobject_isNull = false;
/* 087 */         if (!deserializetoobject_isNull) {
/* 088 */           Object deserializetoobject_funcResult = null;
/* 089 */           deserializetoobject_funcResult = deserializetoobject_value1.toDoubleArray();
/* 090 */           deserializetoobject_value = (double[]) deserializetoobject_funcResult;
/* 091 */
/* 092 */         }
/* 093 */
/* 094 */       }
/* 095 */
/* 096 */       boolean mapelements_isNull = true;
/* 097 */       double[] mapelements_value = null;
/* 098 */       if (!false) {
/* 099 */         mapelements_resultIsNull = false;
/* 100 */
/* 101 */         if (!mapelements_resultIsNull) {
/* 102 */           mapelements_resultIsNull = deserializetoobject_isNull;
/* 103 */           mapelements_argValue = deserializetoobject_value;
/* 104 */         }
/* 105 */
/* 106 */         mapelements_isNull = mapelements_resultIsNull;
/* 107 */         if (!mapelements_isNull) {
/* 108 */           Object mapelements_funcResult = null;
/* 109 */           mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue);
/* 110 */           if (mapelements_funcResult == null) {
/* 111 */             mapelements_isNull = true;
/* 112 */           } else {
/* 113 */             mapelements_value = (double[]) mapelements_funcResult;
/* 114 */           }
/* 115 */
/* 116 */         }
/* 117 */         mapelements_isNull = mapelements_value == null;
/* 118 */       }
/* 119 */
/* 120 */       serializefromobject_resultIsNull = false;
/* 121 */
/* 122 */       if (!serializefromobject_resultIsNull) {
/* 123 */         serializefromobject_resultIsNull = mapelements_isNull;
/* 124 */         serializefromobject_argValue = mapelements_value;
/* 125 */       }
/* 126 */
/* 127 */       boolean serializefromobject_isNull = serializefromobject_resultIsNull;
/* 128 */       final ArrayData serializefromobject_value = serializefromobject_resultIsNull ? null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(serializefromobject_argValue);
/* 129 */       serializefromobject_isNull = serializefromobject_value == null;
/* 130 */       serializefromobject_holder.reset();
/* 131 */
/* 132 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 133 */
/* 134 */       if (serializefromobject_isNull) {
/* 135 */         serializefromobject_rowWriter.setNullAt(0);
/* 136 */       } else {
/* 137 */         // Remember the current cursor so that we can calculate how many bytes are
/* 138 */         // written later.
/* 139 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 140 */
/* 141 */         if (serializefromobject_value instanceof UnsafeArrayData) {
/* 142 */           final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 143 */           // grow the global buffer before writing data.
/* 144 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 145 */           ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 146 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 147 */
/* 148 */         } else {
/* 149 */           final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 150 */           serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 8);
/* 151 */
/* 152 */           for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) {
/* 153 */             if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 154 */               serializefromobject_arrayWriter.setNullDouble(serializefromobject_index);
/* 155 */             } else {
/* 156 */               final double serializefromobject_element = serializefromobject_value.getDouble(serializefromobject_index);
/* 157 */               serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 158 */             }
/* 159 */           }
/* 160 */         }
/* 161 */
/* 162 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 163 */       }
/* 164 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 165 */       append(serializefromobject_result);
/* 166 */       if (shouldStop()) return;
/* 167 */     }
/* 168 */   }
```
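
The mechanism is visible in the ``Invoke`` change below: a call site that knows the invoked routine never returns ``null`` passes ``returnNullable = false``, and the code generator then assigns the result directly instead of emitting the null check. A minimal sketch of such a call site follows; the ``BoundReference`` input is a hypothetical stand-in for the deserializer's input path.

```scala
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types.{ArrayType, DoubleType, ObjectType}

// Hypothetical input: ordinal 0 holds an array of non-null doubles.
val arrayData = BoundReference(0, ArrayType(DoubleType, containsNull = false), nullable = true)

// ArrayData.toDoubleArray() never returns null, so the call site declares
// returnNullable = false and the generated code skips the
// `if (funcResult == null)` branch shown in the "Without this PR" output.
val toDoubleArray = Invoke(
  arrayData,
  "toDoubleArray",
  ObjectType(classOf[Array[Double]]),
  returnNullable = false)
```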

## How was this patch tested?

Added a new test to ``DatasetPrimitiveSuite``.

Author: Kazuaki Ishizaki <is...@jp.ibm.com>

Closes #17569 from kiszk/SPARK-20253.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7a63f5e8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7a63f5e8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7a63f5e8

Branch: refs/heads/master
Commit: 7a63f5e82758345ff1f3322950f2bbea350c48b9
Parents: 261eaf5
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Authored: Mon Apr 10 10:47:17 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Apr 10 10:47:17 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    | 27 +++++++++++--------
 .../sql/catalyst/encoders/RowEncoder.scala      | 19 +++++++------
 .../catalyst/expressions/objects/objects.scala  | 28 ++++++++++----------
 .../spark/sql/DatasetPrimitiveSuite.scala       | 10 +++++++
 4 files changed, 51 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7a63f5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 206ae2f..1981227 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -251,19 +251,22 @@ object ScalaReflection extends ScalaReflection {
           getPath :: Nil)
 
       case t if t <:< localTypeOf[java.lang.String] =>
-        Invoke(getPath, "toString", ObjectType(classOf[String]))
+        Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false)
 
       case t if t <:< localTypeOf[java.math.BigDecimal] =>
-        Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+        Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
+          returnNullable = false)
 
       case t if t <:< localTypeOf[BigDecimal] =>
-        Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]))
+        Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false)
 
       case t if t <:< localTypeOf[java.math.BigInteger] =>
-        Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]))
+        Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]),
+          returnNullable = false)
 
       case t if t <:< localTypeOf[scala.math.BigInt] =>
-        Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]))
+        Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]),
+          returnNullable = false)
 
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
@@ -284,7 +287,7 @@ object ScalaReflection extends ScalaReflection {
         val arrayCls = arrayClassFor(elementType)
 
         if (elementNullable) {
-          Invoke(arrayData, "array", arrayCls)
+          Invoke(arrayData, "array", arrayCls, returnNullable = false)
         } else {
           val primitiveMethod = elementType match {
             case t if t <:< definitions.IntTpe => "toIntArray"
@@ -297,7 +300,7 @@ object ScalaReflection extends ScalaReflection {
             case other => throw new IllegalStateException("expect primitive array element type " +
               "but got " + other)
           }
-          Invoke(arrayData, primitiveMethod, arrayCls)
+          Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false)
         }
 
       case t if t <:< localTypeOf[Seq[_]] =>
@@ -330,19 +333,21 @@ object ScalaReflection extends ScalaReflection {
           Invoke(
             MapObjects(
               p => deserializerFor(keyType, Some(p), walkedTypePath),
-              Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
+              Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType),
+                returnNullable = false),
               schemaFor(keyType).dataType),
             "array",
-            ObjectType(classOf[Array[Any]]))
+            ObjectType(classOf[Array[Any]]), returnNullable = false)
 
         val valueData =
           Invoke(
             MapObjects(
               p => deserializerFor(valueType, Some(p), walkedTypePath),
-              Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
+              Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType),
+                returnNullable = false),
               schemaFor(valueType).dataType),
             "array",
-            ObjectType(classOf[Array[Any]]))
+            ObjectType(classOf[Array[Any]]), returnNullable = false)
 
         StaticInvoke(
           ArrayBasedMapData.getClass,

http://git-wip-us.apache.org/repos/asf/spark/blob/7a63f5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index e95e97b..0f8282d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -89,7 +89,7 @@ object RowEncoder {
         udtClass,
         Nil,
         dataType = ObjectType(udtClass), false)
-      Invoke(obj, "serialize", udt, inputObject :: Nil)
+      Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)
 
     case TimestampType =>
       StaticInvoke(
@@ -136,16 +136,18 @@ object RowEncoder {
     case t @ MapType(kt, vt, valueNullable) =>
       val keys =
         Invoke(
-          Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
+          Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]),
+            returnNullable = false),
           "toSeq",
-          ObjectType(classOf[scala.collection.Seq[_]]))
+          ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
       val convertedKeys = serializerFor(keys, ArrayType(kt, false))
 
       val values =
         Invoke(
-          Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
+          Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]]),
+            returnNullable = false),
           "toSeq",
-          ObjectType(classOf[scala.collection.Seq[_]]))
+          ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
       val convertedValues = serializerFor(values, ArrayType(vt, valueNullable))
 
       NewInstance(
@@ -262,17 +264,18 @@ object RowEncoder {
         input :: Nil)
 
     case _: DecimalType =>
-      Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+      Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
+        returnNullable = false)
 
     case StringType =>
-      Invoke(input, "toString", ObjectType(classOf[String]))
+      Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false)
 
     case ArrayType(et, nullable) =>
       val arrayData =
         Invoke(
           MapObjects(deserializerFor(_), input, et),
           "array",
-          ObjectType(classOf[Array[_]]))
+          ObjectType(classOf[Array[_]]), returnNullable = false)
       StaticInvoke(
         scala.collection.mutable.WrappedArray.getClass,
         ObjectType(classOf[Seq[_]]),

http://git-wip-us.apache.org/repos/asf/spark/blob/7a63f5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 53842ef..6d94764 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -225,25 +225,26 @@ case class Invoke(
       getFuncResult(ev.value, s"${obj.value}.$functionName($argString)")
     } else {
       val funcResult = ctx.freshName("funcResult")
+      // If the function can return null, we do an extra check to make sure our null bit is still
+      // set correctly.
+      val assignResult = if (!returnNullable) {
+        s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;"
+      } else {
+        s"""
+          if ($funcResult != null) {
+            ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;
+          } else {
+            ${ev.isNull} = true;
+          }
+        """
+      }
       s"""
         Object $funcResult = null;
         ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")}
-        if ($funcResult == null) {
-          ${ev.isNull} = true;
-        } else {
-          ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;
-        }
+        $assignResult
       """
     }
 
-    // If the function can return null, we do an extra check to make sure our null bit is still set
-    // correctly.
-    val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
-      s"${ev.isNull} = ${ev.value} == null;"
-    } else {
-      ""
-    }
-
     val code = s"""
       ${obj.code}
       boolean ${ev.isNull} = true;
@@ -254,7 +255,6 @@ case class Invoke(
         if (!${ev.isNull}) {
           $evaluate
         }
-        $postNullCheck
       }
      """
     ev.copy(code = code)

http://git-wip-us.apache.org/repos/asf/spark/blob/7a63f5e8/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 82b7075..5415653 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -96,6 +96,16 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     checkDataset(dsBoolean.map(e => !e), false, true)
   }
 
+  test("mapPrimitiveArray") {
+    val dsInt = Seq(Array(1, 2), Array(3, 4)).toDS()
+    checkDataset(dsInt.map(e => e), Array(1, 2), Array(3, 4))
+    checkDataset(dsInt.map(e => null: Array[Int]), null, null)
+
+    val dsDouble = Seq(Array(1D, 2D), Array(3D, 4D)).toDS()
+    checkDataset(dsDouble.map(e => e), Array(1D, 2D), Array(3D, 4D))
+    checkDataset(dsDouble.map(e => null: Array[Double]), null, null)
+  }
+
   test("filter") {
     val ds = Seq(1, 2, 3, 4).toDS()
     checkDataset(

