You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2016/11/21 04:05:08 UTC

spark git commit: [SPARK-18467][SQL] Extracts method for preparing arguments from StaticInvoke, Invoke and NewInstance and modify to short circuit if arguments have null when `needNullCheck == true`.

Repository: spark
Updated Branches:
  refs/heads/master b625a36eb -> 658547974


[SPARK-18467][SQL] Extracts method for preparing arguments from StaticInvoke, Invoke and NewInstance and modify to short circuit if arguments have null when `needNullCheck == true`.

## What changes were proposed in this pull request?

This PR extracts a common method for preparing arguments from `StaticInvoke`, `Invoke` and `NewInstance`, and modifies it to short-circuit if any argument is `null` when `propagateNull == true`.

The steps are as follows:

1. Introduce `InvokeLike` to extract the common argument-preparation logic from `StaticInvoke`, `Invoke` and `NewInstance`.
`StaticInvoke` and `Invoke` risked exceeding the 64KB JVM limit when preparing arguments, but after this patch they can handle it because they share the argument-preparation code of `NewInstance`, which handles the limit well.

2. Remove unneeded null checking and fix the nullability of `NewInstance`.
Avoid null checks that are not needed because the expression is not nullable.

3. Short-circuit if any argument is `null` when `needNullCheck == true`.
If `needNullCheck == true`, preparing the remaining arguments can be skipped once one of them is found to be `null`, so the generated code now short-circuits in that case.

## How was this patch tested?

Existing tests.

Author: Takuya UESHIN <ue...@happy-camper.st>

Closes #15901 from ueshin/issues/SPARK-18467.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/65854797
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/65854797
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/65854797

Branch: refs/heads/master
Commit: 658547974915ebcaae83e13e4c3bdf68d5426fda
Parents: b625a36
Author: Takuya UESHIN <ue...@happy-camper.st>
Authored: Mon Nov 21 12:05:01 2016 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Nov 21 12:05:01 2016 +0800

----------------------------------------------------------------------
 .../catalyst/expressions/objects/objects.scala  | 163 ++++++++++++-------
 1 file changed, 101 insertions(+), 62 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/65854797/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 0e3d991..0b36091 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -33,6 +33,78 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import org.apache.spark.sql.types._
 
 /**
+ * Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]].
+ */
+trait InvokeLike extends Expression with NonSQLExpression {
+
+  def arguments: Seq[Expression]
+
+  def propagateNull: Boolean
+
+  protected lazy val needNullCheck: Boolean = propagateNull && arguments.exists(_.nullable)
+
+  /**
+   * Prepares codes for arguments.
+   *
+   * - generate codes for argument.
+   * - use ctx.splitExpressions() to not exceed 64kb JVM limit while preparing arguments.
+   * - avoid null checks that are not needed because the expression is not
+   *   nullable.
+   * - when needNullCheck == true, short circuit if we found one of arguments is null because
+   *   preparing rest of arguments can be skipped in the case.
+   *
+   * @param ctx a [[CodegenContext]]
+   * @return (code to prepare arguments, argument string, result of argument null check)
+   */
+  def prepareArguments(ctx: CodegenContext): (String, String, String) = {
+
+    val resultIsNull = if (needNullCheck) {
+      val resultIsNull = ctx.freshName("resultIsNull")
+      ctx.addMutableState("boolean", resultIsNull, "")
+      resultIsNull
+    } else {
+      "false"
+    }
+    val argValues = arguments.map { e =>
+      val argValue = ctx.freshName("argValue")
+      ctx.addMutableState(ctx.javaType(e.dataType), argValue, "")
+      argValue
+    }
+
+    val argCodes = if (needNullCheck) {
+      val reset = s"$resultIsNull = false;"
+      val argCodes = arguments.zipWithIndex.map { case (e, i) =>
+        val expr = e.genCode(ctx)
+        val updateResultIsNull = if (e.nullable) {
+          s"$resultIsNull = ${expr.isNull};"
+        } else {
+          ""
+        }
+        s"""
+          if (!$resultIsNull) {
+            ${expr.code}
+            $updateResultIsNull
+            ${argValues(i)} = ${expr.value};
+          }
+        """
+      }
+      reset +: argCodes
+    } else {
+      arguments.zipWithIndex.map { case (e, i) =>
+        val expr = e.genCode(ctx)
+        s"""
+          ${expr.code}
+          ${argValues(i)} = ${expr.value};
+        """
+      }
+    }
+    val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes)
+
+    (argCode, argValues.mkString(", "), resultIsNull)
+  }
+}
+
+/**
  * Invokes a static function, returning the result.  By default, any of the arguments being null
  * will result in returning null instead of calling the function.
  *
@@ -50,7 +122,7 @@ case class StaticInvoke(
     dataType: DataType,
     functionName: String,
     arguments: Seq[Expression] = Nil,
-    propagateNull: Boolean = true) extends Expression with NonSQLExpression {
+    propagateNull: Boolean = true) extends InvokeLike {
 
   val objectName = staticObject.getName.stripSuffix("$")
 
@@ -62,16 +134,10 @@ case class StaticInvoke(
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val javaType = ctx.javaType(dataType)
-    val argGen = arguments.map(_.genCode(ctx))
-    val argString = argGen.map(_.value).mkString(", ")
 
-    val callFunc = s"$objectName.$functionName($argString)"
+    val (argCode, argString, resultIsNull) = prepareArguments(ctx)
 
-    val setIsNull = if (propagateNull && arguments.nonEmpty) {
-      s"boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};"
-    } else {
-      s"boolean ${ev.isNull} = false;"
-    }
+    val callFunc = s"$objectName.$functionName($argString)"
 
     // If the function can return null, we do an extra check to make sure our null bit is still set
     // correctly.
@@ -82,9 +148,9 @@ case class StaticInvoke(
     }
 
     val code = s"""
-      ${argGen.map(_.code).mkString("\n")}
-      $setIsNull
-      final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;
+      $argCode
+      boolean ${ev.isNull} = $resultIsNull;
+      final $javaType ${ev.value} = $resultIsNull ? ${ctx.defaultValue(dataType)} : $callFunc;
       $postNullCheck
      """
     ev.copy(code = code)
@@ -103,13 +169,15 @@ case class StaticInvoke(
  * @param functionName The name of the method to call.
  * @param dataType The expected return type of the function.
  * @param arguments An optional list of expressions, whos evaluation will be passed to the function.
+ * @param propagateNull When true, and any of the arguments is null, null will be returned instead
+ *                      of calling the function.
  */
 case class Invoke(
     targetObject: Expression,
     functionName: String,
     dataType: DataType,
     arguments: Seq[Expression] = Nil,
-    propagateNull: Boolean = true) extends Expression with NonSQLExpression {
+    propagateNull: Boolean = true) extends InvokeLike {
 
   override def nullable: Boolean = true
   override def children: Seq[Expression] = targetObject +: arguments
@@ -131,8 +199,8 @@ case class Invoke(
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val javaType = ctx.javaType(dataType)
     val obj = targetObject.genCode(ctx)
-    val argGen = arguments.map(_.genCode(ctx))
-    val argString = argGen.map(_.value).mkString(", ")
+
+    val (argCode, argString, resultIsNull) = prepareArguments(ctx)
 
     val returnPrimitive = method.isDefined && method.get.getReturnType.isPrimitive
     val needTryCatch = method.isDefined && method.get.getExceptionTypes.nonEmpty
@@ -164,12 +232,6 @@ case class Invoke(
       """
     }
 
-    val setIsNull = if (propagateNull && arguments.nonEmpty) {
-      s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};"
-    } else {
-      s"boolean ${ev.isNull} = ${obj.isNull};"
-    }
-
     // If the function can return null, we do an extra check to make sure our null bit is still set
     // correctly.
     val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
@@ -177,15 +239,19 @@ case class Invoke(
     } else {
       ""
     }
+
     val code = s"""
       ${obj.code}
-      ${argGen.map(_.code).mkString("\n")}
-      $setIsNull
+      boolean ${ev.isNull} = true;
       $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
-      if (!${ev.isNull}) {
-        $evaluate
+      if (!${obj.isNull}) {
+        $argCode
+        ${ev.isNull} = $resultIsNull;
+        if (!${ev.isNull}) {
+          $evaluate
+        }
+        $postNullCheck
       }
-      $postNullCheck
      """
     ev.copy(code = code)
   }
@@ -223,10 +289,10 @@ case class NewInstance(
     arguments: Seq[Expression],
     propagateNull: Boolean,
     dataType: DataType,
-    outerPointer: Option[() => AnyRef]) extends Expression with NonSQLExpression {
+    outerPointer: Option[() => AnyRef]) extends InvokeLike {
   private val className = cls.getName
 
-  override def nullable: Boolean = propagateNull
+  override def nullable: Boolean = needNullCheck
 
   override def children: Seq[Expression] = arguments
 
@@ -245,52 +311,25 @@ case class NewInstance(
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val javaType = ctx.javaType(dataType)
-    val argIsNulls = ctx.freshName("argIsNulls")
-    ctx.addMutableState("boolean[]", argIsNulls,
-      s"$argIsNulls = new boolean[${arguments.size}];")
-    val argValues = arguments.zipWithIndex.map { case (e, i) =>
-      val argValue = ctx.freshName("argValue")
-      ctx.addMutableState(ctx.javaType(e.dataType), argValue, "")
-      argValue
-    }
 
-    val argCodes = arguments.zipWithIndex.map { case (e, i) =>
-      val expr = e.genCode(ctx)
-      expr.code + s"""
-       $argIsNulls[$i] = ${expr.isNull};
-       ${argValues(i)} = ${expr.value};
-     """
-    }
-    val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes)
+    val (argCode, argString, resultIsNull) = prepareArguments(ctx)
 
     val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx))
 
-    var isNull = ev.isNull
-    val setIsNull = if (propagateNull && arguments.nonEmpty) {
-      s"""
-       boolean $isNull = false;
-       for (int idx = 0; idx < ${arguments.length}; idx++) {
-         if ($argIsNulls[idx]) { $isNull = true; break; }
-       }
-     """
-    } else {
-      isNull = "false"
-      ""
-    }
+    ev.isNull = resultIsNull
 
     val constructorCall = outer.map { gen =>
-      s"""${gen.value}.new ${cls.getSimpleName}(${argValues.mkString(", ")})"""
+      s"${gen.value}.new ${cls.getSimpleName}($argString)"
     }.getOrElse {
-      s"new $className(${argValues.mkString(", ")})"
+      s"new $className($argString)"
     }
 
     val code = s"""
       $argCode
       ${outer.map(_.code).getOrElse("")}
-      $setIsNull
-      final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall;
-     """
-    ev.copy(code = code, isNull = isNull)
+      final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall;
+    """
+    ev.copy(code = code)
   }
 
   override def toString: String = s"newInstance($cls)"


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org