You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/07/05 06:25:31 UTC

spark git commit: [SPARK-18623][SQL] Add `returnNullable` to `StaticInvoke` and modify it to handle nullability properly.

Repository: spark
Updated Branches:
  refs/heads/master f2c3b1dd6 -> a38643256


[SPARK-18623][SQL] Add `returnNullable` to `StaticInvoke` and modify it to handle nullability properly.

## What changes were proposed in this pull request?

Add `returnNullable` to `StaticInvoke`, the same as #15780 is trying to add to `Invoke`, and modify `StaticInvoke` to handle its nullability properly.

## How was this patch tested?

Existing tests.

Author: Takuya UESHIN <ue...@happy-camper.st>
Author: Takuya UESHIN <ue...@databricks.com>

Closes #16056 from ueshin/issues/SPARK-18623.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a3864325
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a3864325
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a3864325

Branch: refs/heads/master
Commit: a38643256691947ff7f7c474b85c052a7d5d8553
Parents: f2c3b1d
Author: Takuya UESHIN <ue...@happy-camper.st>
Authored: Wed Jul 5 14:25:26 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Wed Jul 5 14:25:26 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/JavaTypeInference.scala  | 21 ++++++----
 .../spark/sql/catalyst/ScalaReflection.scala    | 44 ++++++++++++--------
 .../sql/catalyst/encoders/RowEncoder.scala      | 27 ++++++++----
 .../catalyst/expressions/objects/objects.scala  | 42 +++++++++++++++----
 4 files changed, 91 insertions(+), 43 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a3864325/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 90ec699..21363d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -216,7 +216,7 @@ object JavaTypeInference {
           ObjectType(c),
           "valueOf",
           getPath :: Nil,
-          propagateNull = true)
+          returnNullable = false)
 
       case c if c == classOf[java.sql.Date] =>
         StaticInvoke(
@@ -224,7 +224,7 @@ object JavaTypeInference {
           ObjectType(c),
           "toJavaDate",
           getPath :: Nil,
-          propagateNull = true)
+          returnNullable = false)
 
       case c if c == classOf[java.sql.Timestamp] =>
         StaticInvoke(
@@ -232,7 +232,7 @@ object JavaTypeInference {
           ObjectType(c),
           "toJavaTimestamp",
           getPath :: Nil,
-          propagateNull = true)
+          returnNullable = false)
 
       case c if c == classOf[java.lang.String] =>
         Invoke(getPath, "toString", ObjectType(classOf[String]))
@@ -300,7 +300,8 @@ object JavaTypeInference {
           ArrayBasedMapData.getClass,
           ObjectType(classOf[JMap[_, _]]),
           "toJavaMap",
-          keyData :: valueData :: Nil)
+          keyData :: valueData :: Nil,
+          returnNullable = false)
 
       case other =>
         val properties = getJavaBeanReadableAndWritableProperties(other)
@@ -367,28 +368,32 @@ object JavaTypeInference {
             classOf[UTF8String],
             StringType,
             "fromString",
-            inputObject :: Nil)
+            inputObject :: Nil,
+            returnNullable = false)
 
         case c if c == classOf[java.sql.Timestamp] =>
           StaticInvoke(
             DateTimeUtils.getClass,
             TimestampType,
             "fromJavaTimestamp",
-            inputObject :: Nil)
+            inputObject :: Nil,
+            returnNullable = false)
 
         case c if c == classOf[java.sql.Date] =>
           StaticInvoke(
             DateTimeUtils.getClass,
             DateType,
             "fromJavaDate",
-            inputObject :: Nil)
+            inputObject :: Nil,
+            returnNullable = false)
 
         case c if c == classOf[java.math.BigDecimal] =>
           StaticInvoke(
             Decimal.getClass,
             DecimalType.SYSTEM_DEFAULT,
             "apply",
-            inputObject :: Nil)
+            inputObject :: Nil,
+            returnNullable = false)
 
         case c if c == classOf[java.lang.Boolean] =>
           Invoke(inputObject, "booleanValue", BooleanType)

http://git-wip-us.apache.org/repos/asf/spark/blob/a3864325/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index bea0de4..814f2c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -206,51 +206,53 @@ object ScalaReflection extends ScalaReflection {
       case t if t <:< localTypeOf[java.lang.Integer] =>
         val boxedType = classOf[java.lang.Integer]
         val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
+        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
 
       case t if t <:< localTypeOf[java.lang.Long] =>
         val boxedType = classOf[java.lang.Long]
         val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
+        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
 
       case t if t <:< localTypeOf[java.lang.Double] =>
         val boxedType = classOf[java.lang.Double]
         val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
+        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
 
       case t if t <:< localTypeOf[java.lang.Float] =>
         val boxedType = classOf[java.lang.Float]
         val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
+        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
 
       case t if t <:< localTypeOf[java.lang.Short] =>
         val boxedType = classOf[java.lang.Short]
         val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
+        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
 
       case t if t <:< localTypeOf[java.lang.Byte] =>
         val boxedType = classOf[java.lang.Byte]
         val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
+        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
 
       case t if t <:< localTypeOf[java.lang.Boolean] =>
         val boxedType = classOf[java.lang.Boolean]
         val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
+        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
 
       case t if t <:< localTypeOf[java.sql.Date] =>
         StaticInvoke(
           DateTimeUtils.getClass,
           ObjectType(classOf[java.sql.Date]),
           "toJavaDate",
-          getPath :: Nil)
+          getPath :: Nil,
+          returnNullable = false)
 
       case t if t <:< localTypeOf[java.sql.Timestamp] =>
         StaticInvoke(
           DateTimeUtils.getClass,
           ObjectType(classOf[java.sql.Timestamp]),
           "toJavaTimestamp",
-          getPath :: Nil)
+          getPath :: Nil,
+          returnNullable = false)
 
       case t if t <:< localTypeOf[java.lang.String] =>
         Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false)
@@ -446,7 +448,8 @@ object ScalaReflection extends ScalaReflection {
               classOf[UnsafeArrayData],
               ArrayType(dt, false),
               "fromPrimitiveArray",
-              input :: Nil)
+              input :: Nil,
+              returnNullable = false)
           } else {
             NewInstance(
               classOf[GenericArrayData],
@@ -504,49 +507,56 @@ object ScalaReflection extends ScalaReflection {
           classOf[UTF8String],
           StringType,
           "fromString",
-          inputObject :: Nil)
+          inputObject :: Nil,
+          returnNullable = false)
 
       case t if t <:< localTypeOf[java.sql.Timestamp] =>
         StaticInvoke(
           DateTimeUtils.getClass,
           TimestampType,
           "fromJavaTimestamp",
-          inputObject :: Nil)
+          inputObject :: Nil,
+          returnNullable = false)
 
       case t if t <:< localTypeOf[java.sql.Date] =>
         StaticInvoke(
           DateTimeUtils.getClass,
           DateType,
           "fromJavaDate",
-          inputObject :: Nil)
+          inputObject :: Nil,
+          returnNullable = false)
 
       case t if t <:< localTypeOf[BigDecimal] =>
         StaticInvoke(
           Decimal.getClass,
           DecimalType.SYSTEM_DEFAULT,
           "apply",
-          inputObject :: Nil)
+          inputObject :: Nil,
+          returnNullable = false)
 
       case t if t <:< localTypeOf[java.math.BigDecimal] =>
         StaticInvoke(
           Decimal.getClass,
           DecimalType.SYSTEM_DEFAULT,
           "apply",
-          inputObject :: Nil)
+          inputObject :: Nil,
+          returnNullable = false)
 
       case t if t <:< localTypeOf[java.math.BigInteger] =>
         StaticInvoke(
           Decimal.getClass,
           DecimalType.BigIntDecimal,
           "apply",
-          inputObject :: Nil)
+          inputObject :: Nil,
+          returnNullable = false)
 
       case t if t <:< localTypeOf[scala.math.BigInt] =>
         StaticInvoke(
           Decimal.getClass,
           DecimalType.BigIntDecimal,
           "apply",
-          inputObject :: Nil)
+          inputObject :: Nil,
+          returnNullable = false)
 
       case t if t <:< localTypeOf[java.lang.Integer] =>
         Invoke(inputObject, "intValue", IntegerType)

http://git-wip-us.apache.org/repos/asf/spark/blob/a3864325/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 0f8282d..cc32fac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -96,28 +96,32 @@ object RowEncoder {
         DateTimeUtils.getClass,
         TimestampType,
         "fromJavaTimestamp",
-        inputObject :: Nil)
+        inputObject :: Nil,
+        returnNullable = false)
 
     case DateType =>
       StaticInvoke(
         DateTimeUtils.getClass,
         DateType,
         "fromJavaDate",
-        inputObject :: Nil)
+        inputObject :: Nil,
+        returnNullable = false)
 
     case d: DecimalType =>
       StaticInvoke(
         Decimal.getClass,
         d,
         "fromDecimal",
-        inputObject :: Nil)
+        inputObject :: Nil,
+        returnNullable = false)
 
     case StringType =>
       StaticInvoke(
         classOf[UTF8String],
         StringType,
         "fromString",
-        inputObject :: Nil)
+        inputObject :: Nil,
+        returnNullable = false)
 
     case t @ ArrayType(et, cn) =>
       et match {
@@ -126,7 +130,8 @@ object RowEncoder {
             classOf[ArrayData],
             t,
             "toArrayData",
-            inputObject :: Nil)
+            inputObject :: Nil,
+            returnNullable = false)
         case _ => MapObjects(
           element => serializerFor(ValidateExternalType(element, et), et),
           inputObject,
@@ -254,14 +259,16 @@ object RowEncoder {
         DateTimeUtils.getClass,
         ObjectType(classOf[java.sql.Timestamp]),
         "toJavaTimestamp",
-        input :: Nil)
+        input :: Nil,
+        returnNullable = false)
 
     case DateType =>
       StaticInvoke(
         DateTimeUtils.getClass,
         ObjectType(classOf[java.sql.Date]),
         "toJavaDate",
-        input :: Nil)
+        input :: Nil,
+        returnNullable = false)
 
     case _: DecimalType =>
       Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
@@ -280,7 +287,8 @@ object RowEncoder {
         scala.collection.mutable.WrappedArray.getClass,
         ObjectType(classOf[Seq[_]]),
         "make",
-        arrayData :: Nil)
+        arrayData :: Nil,
+        returnNullable = false)
 
     case MapType(kt, vt, valueNullable) =>
       val keyArrayType = ArrayType(kt, false)
@@ -293,7 +301,8 @@ object RowEncoder {
         ArrayBasedMapData.getClass,
         ObjectType(classOf[Map[_, _]]),
         "toScalaMap",
-        keyData :: valueData :: Nil)
+        keyData :: valueData :: Nil,
+        returnNullable = false)
 
     case schema @ StructType(fields) =>
       val convertedFields = fields.zipWithIndex.map { case (f, i) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/a3864325/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index ce07f4a..24c06d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -118,17 +118,20 @@ trait InvokeLike extends Expression with NonSQLExpression {
  * @param arguments An optional list of expressions to pass as arguments to the function.
  * @param propagateNull When true, and any of the arguments is null, null will be returned instead
  *                      of calling the function.
+ * @param returnNullable When false, indicates the invoked method always returns a non-null
+ *                       value, so no null check is generated for its result.
  */
 case class StaticInvoke(
     staticObject: Class[_],
     dataType: DataType,
     functionName: String,
     arguments: Seq[Expression] = Nil,
-    propagateNull: Boolean = true) extends InvokeLike {
+    propagateNull: Boolean = true,
+    returnNullable: Boolean = true) extends InvokeLike {
 
   val objectName = staticObject.getName.stripSuffix("$")
 
-  override def nullable: Boolean = true
+  override def nullable: Boolean = needNullCheck || returnNullable
   override def children: Seq[Expression] = arguments
 
   override def eval(input: InternalRow): Any =
@@ -141,19 +144,40 @@ case class StaticInvoke(
 
     val callFunc = s"$objectName.$functionName($argString)"
 
-    // If the function can return null, we do an extra check to make sure our null bit is still set
-    // correctly.
-    val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
-      s"${ev.isNull} = ${ev.value} == null;"
+    val prepareIsNull = if (nullable) {
+      s"boolean ${ev.isNull} = $resultIsNull;"
     } else {
+      ev.isNull = "false"
       ""
     }
 
+    val evaluate = if (returnNullable) {
+      if (ctx.defaultValue(dataType) == "null") {
+        s"""
+          ${ev.value} = $callFunc;
+          ${ev.isNull} = ${ev.value} == null;
+        """
+      } else {
+        val boxedResult = ctx.freshName("boxedResult")
+        s"""
+          ${ctx.boxedType(dataType)} $boxedResult = $callFunc;
+          ${ev.isNull} = $boxedResult == null;
+          if (!${ev.isNull}) {
+            ${ev.value} = $boxedResult;
+          }
+        """
+      }
+    } else {
+      s"${ev.value} = $callFunc;"
+    }
+
     val code = s"""
       $argCode
-      boolean ${ev.isNull} = $resultIsNull;
-      final $javaType ${ev.value} = $resultIsNull ? ${ctx.defaultValue(dataType)} : $callFunc;
-      $postNullCheck
+      $prepareIsNull
+      $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
+      if (!$resultIsNull) {
+        $evaluate
+      }
      """
     ev.copy(code = code)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org