You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2016/02/01 07:43:16 UTC

spark git commit: [SPARK-13093] [SQL] improve null check in nullSafeCodeGen for unary, binary and ternary expression

Repository: spark
Updated Branches:
  refs/heads/master 5a8b978fa -> c1da4d421


[SPARK-13093] [SQL] improve null check in nullSafeCodeGen for unary, binary and ternary expression

The current implementation is sub-optimal:

* If an expression is always nullable, e.g. `Unhex`, we can still remove the null check for any child that is not nullable.
* If an expression has some non-nullable children, we can still remove the null checks for those children and keep them for the others.

This PR improves this by making the null check elimination more fine-grained.

Author: Wenchen Fan <we...@databricks.com>

Closes #10987 from cloud-fan/null-check.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c1da4d42
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c1da4d42
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c1da4d42

Branch: refs/heads/master
Commit: c1da4d421ab78772ffa52ad46e5bdfb4e5268f47
Parents: 5a8b978
Author: Wenchen Fan <we...@databricks.com>
Authored: Sun Jan 31 22:43:03 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Sun Jan 31 22:43:03 2016 -0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/Expression.scala   | 104 ++++++++++---------
 .../expressions/codegen/CodeGenerator.scala     |  32 ++++--
 .../spark/sql/catalyst/expressions/misc.scala   |  16 +--
 3 files changed, 85 insertions(+), 67 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c1da4d42/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index db17ba7..353fb92 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -320,7 +320,7 @@ abstract class UnaryExpression extends Expression {
 
   /**
    * Called by unary expressions to generate a code block that returns null if its parent returns
-   * null, and if not not null, use `f` to generate the expression.
+   * null, and if not null, use `f` to generate the expression.
    *
    * As an example, the following does a boolean inversion (i.e. NOT).
    * {{{
@@ -340,7 +340,7 @@ abstract class UnaryExpression extends Expression {
 
   /**
    * Called by unary expressions to generate a code block that returns null if its parent returns
-   * null, and if not not null, use `f` to generate the expression.
+   * null, and if not null, use `f` to generate the expression.
    *
    * @param f function that accepts the non-null evaluation result name of child and returns Java
    *          code to compute the output.
@@ -349,20 +349,23 @@ abstract class UnaryExpression extends Expression {
       ctx: CodegenContext,
       ev: ExprCode,
       f: String => String): String = {
-    val eval = child.gen(ctx)
+    val childGen = child.gen(ctx)
+    val resultCode = f(childGen.value)
+
     if (nullable) {
-      eval.code + s"""
-        boolean ${ev.isNull} = ${eval.isNull};
+      val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode)
+      s"""
+        ${childGen.code}
+        boolean ${ev.isNull} = ${childGen.isNull};
         ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-        if (!${eval.isNull}) {
-          ${f(eval.value)}
-        }
+        $nullSafeEval
       """
     } else {
       ev.isNull = "false"
-      eval.code + s"""
+      s"""
+        ${childGen.code}
         ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-        ${f(eval.value)}
+        $resultCode
       """
     }
   }
@@ -440,29 +443,31 @@ abstract class BinaryExpression extends Expression {
       ctx: CodegenContext,
       ev: ExprCode,
       f: (String, String) => String): String = {
-    val eval1 = left.gen(ctx)
-    val eval2 = right.gen(ctx)
-    val resultCode = f(eval1.value, eval2.value)
+    val leftGen = left.gen(ctx)
+    val rightGen = right.gen(ctx)
+    val resultCode = f(leftGen.value, rightGen.value)
+
     if (nullable) {
+      val nullSafeEval =
+        leftGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) {
+          rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) {
+            s"""
+              ${ev.isNull} = false; // resultCode could change nullability.
+              $resultCode
+            """
+          }
+      }
+
       s"""
-        ${eval1.code}
-        boolean ${ev.isNull} = ${eval1.isNull};
+        boolean ${ev.isNull} = true;
         ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-        if (!${ev.isNull}) {
-          ${eval2.code}
-          if (!${eval2.isNull}) {
-            $resultCode
-          } else {
-            ${ev.isNull} = true;
-          }
-        }
+        $nullSafeEval
       """
-
     } else {
       ev.isNull = "false"
       s"""
-        ${eval1.code}
-        ${eval2.code}
+        ${leftGen.code}
+        ${rightGen.code}
         ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
         $resultCode
       """
@@ -527,7 +532,7 @@ abstract class TernaryExpression extends Expression {
 
   /**
    * Default behavior of evaluation according to the default nullability of TernaryExpression.
-   * If subclass of BinaryExpression override nullable, probably should also override this.
+   * If subclass of TernaryExpression override nullable, probably should also override this.
    */
   override def eval(input: InternalRow): Any = {
     val exprs = children
@@ -553,11 +558,11 @@ abstract class TernaryExpression extends Expression {
     sys.error(s"BinaryExpressions must override either eval or nullSafeEval")
 
   /**
-   * Short hand for generating binary evaluation code.
+   * Short hand for generating ternary evaluation code.
    * If either of the sub-expressions is null, the result of this computation
    * is assumed to be null.
    *
-   * @param f accepts two variable names and returns Java code to compute the output.
+   * @param f accepts three variable names and returns Java code to compute the output.
    */
   protected def defineCodeGen(
     ctx: CodegenContext,
@@ -569,41 +574,46 @@ abstract class TernaryExpression extends Expression {
   }
 
   /**
-   * Short hand for generating binary evaluation code.
+   * Short hand for generating ternary evaluation code.
    * If either of the sub-expressions is null, the result of this computation
    * is assumed to be null.
    *
-   * @param f function that accepts the 2 non-null evaluation result names of children
+   * @param f function that accepts the 3 non-null evaluation result names of children
    *          and returns Java code to compute the output.
    */
   protected def nullSafeCodeGen(
     ctx: CodegenContext,
     ev: ExprCode,
     f: (String, String, String) => String): String = {
-    val evals = children.map(_.gen(ctx))
-    val resultCode = f(evals(0).value, evals(1).value, evals(2).value)
+    val leftGen = children(0).gen(ctx)
+    val midGen = children(1).gen(ctx)
+    val rightGen = children(2).gen(ctx)
+    val resultCode = f(leftGen.value, midGen.value, rightGen.value)
+
     if (nullable) {
+      val nullSafeEval =
+        leftGen.code + ctx.nullSafeExec(children(0).nullable, leftGen.isNull) {
+          midGen.code + ctx.nullSafeExec(children(1).nullable, midGen.isNull) {
+            rightGen.code + ctx.nullSafeExec(children(2).nullable, rightGen.isNull) {
+              s"""
+                ${ev.isNull} = false; // resultCode could change nullability.
+                $resultCode
+              """
+            }
+          }
+      }
+
       s"""
-        ${evals(0).code}
         boolean ${ev.isNull} = true;
         ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-        if (!${evals(0).isNull}) {
-          ${evals(1).code}
-          if (!${evals(1).isNull}) {
-            ${evals(2).code}
-            if (!${evals(2).isNull}) {
-              ${ev.isNull} = false;  // resultCode could change nullability
-              $resultCode
-            }
-          }
-        }
+        $nullSafeEval
       """
     } else {
       ev.isNull = "false"
       s"""
-        ${evals(0).code}
-        ${evals(1).code}
-        ${evals(2).code}
+        ${leftGen.code}
+        ${midGen.code}
+        ${rightGen.code}
         ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
         $resultCode
       """

http://git-wip-us.apache.org/repos/asf/spark/blob/c1da4d42/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 21f9198..a30aba1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -402,18 +402,38 @@ class CodegenContext {
   }
 
   /**
-    * Generates code for greater of two expressions.
-    *
-    * @param dataType data type of the expressions
-    * @param c1 name of the variable of expression 1's output
-    * @param c2 name of the variable of expression 2's output
-    */
+   * Generates code for greater of two expressions.
+   *
+   * @param dataType data type of the expressions
+   * @param c1 name of the variable of expression 1's output
+   * @param c2 name of the variable of expression 2's output
+   */
   def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match {
     case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2"
     case _ => s"(${genComp(dataType, c1, c2)}) > 0"
   }
 
   /**
+   * Generates code to do null safe execution, i.e. only execute the code when the input is not
+   * null by adding null check if necessary.
+   *
+   * @param nullable used to decide whether we should add null check or not.
+   * @param isNull the code to check if the input is null.
+   * @param execute the code that should only be executed when the input is not null.
+   */
+  def nullSafeExec(nullable: Boolean, isNull: String)(execute: String): String = {
+    if (nullable) {
+      s"""
+        if (!$isNull) {
+          $execute
+        }
+      """
+    } else {
+      "\n" + execute
+    }
+  }
+
+  /**
    * List of java data types that have special accessors and setters in [[InternalRow]].
    */
   val primitiveTypes =

http://git-wip-us.apache.org/repos/asf/spark/blob/c1da4d42/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 8480c3f..36e1fa1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -327,7 +327,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
     ev.isNull = "false"
     val childrenHash = children.map { child =>
       val childGen = child.gen(ctx)
-      childGen.code + generateNullCheck(child.nullable, childGen.isNull) {
+      childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
         computeHash(childGen.value, child.dataType, ev.value, ctx)
       }
     }.mkString("\n")
@@ -338,18 +338,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
     """
   }
 
-  private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = {
-    if (nullable) {
-      s"""
-        if (!$isNull) {
-          $execution
-        }
-      """
-    } else {
-      "\n" + execution
-    }
-  }
-
   private def nullSafeElementHash(
       input: String,
       index: String,
@@ -359,7 +347,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
       ctx: CodegenContext): String = {
     val element = ctx.freshName("element")
 
-    generateNullCheck(nullable, s"$input.isNullAt($index)") {
+    ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") {
       s"""
         final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)};
         ${computeHash(element, elementType, result, ctx)}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org