Posted to commits@spark.apache.org by rx...@apache.org on 2016/01/12 20:13:11 UTC

spark git commit: [SPARK-12768][SQL] Remove CaseKeyWhen expression

Repository: spark
Updated Branches:
  refs/heads/master 508592b1b -> 0ed430e31


[SPARK-12768][SQL] Remove CaseKeyWhen expression

This patch removes the CaseKeyWhen expression and replaces it with a factory method that generates the equivalent CaseWhen. This reduces the amount of code we need to maintain in the future for both code generation and the optimizer.

Note that we introduced CaseKeyWhen to avoid duplicate evaluations of the key. This is no longer a problem because we now have common subexpression elimination.
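
As a minimal sketch of what the new factory produces (using the definitions introduced in the diff below; Literal and EqualTo are the usual Catalyst expressions):

    import org.apache.spark.sql.catalyst.expressions.{CaseKeyWhen, CaseWhen, EqualTo, Literal}

    // CASE key WHEN 1 THEN "a" ELSE "b" END, built through the new factory method.
    val key = Literal(1)
    val cw: CaseWhen = CaseKeyWhen(key, Seq(Literal(1), Literal("a"), Literal("b")))

    // Each WHEN key has been rewritten into an equality test against `key`, so the
    // branches are Seq(EqualTo(key, Literal(1)), Literal("a"), Literal("b")).
    assert(cw.branches.head.isInstanceOf[EqualTo])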

Author: Reynold Xin <rx...@databricks.com>

Closes #10722 from rxin/SPARK-12768.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0ed430e3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0ed430e3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0ed430e3

Branch: refs/heads/master
Commit: 0ed430e315b9a409490a3604a619321b476cb520
Parents: 508592b
Author: Reynold Xin <rx...@databricks.com>
Authored: Tue Jan 12 11:13:08 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Jan 12 11:13:08 2016 -0800

----------------------------------------------------------------------
 .../catalyst/analysis/HiveTypeCoercion.scala    |  20 +-
 .../expressions/conditionalExpressions.scala    | 187 ++++---------------
 .../analysis/HiveTypeCoercionSuite.scala        |   2 +-
 3 files changed, 38 insertions(+), 171 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0ed430e3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index e326ea7..75c36d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -638,8 +638,7 @@ object HiveTypeCoercion {
    */
   object CaseWhenCoercion extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
-      case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual =>
-        logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}")
+      case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual =>
         val maybeCommonType = findWiderCommonType(c.valueTypes)
         maybeCommonType.map { commonType =>
           val castedBranches = c.branches.grouped(2).map {
@@ -649,22 +648,7 @@ object HiveTypeCoercion {
               Seq(Cast(elseVal, commonType))
             case other => other
           }.reduce(_ ++ _)
-          c match {
-            case _: CaseWhen => CaseWhen(castedBranches)
-            case CaseKeyWhen(key, _) => CaseKeyWhen(key, castedBranches)
-          }
-        }.getOrElse(c)
-
-      case c: CaseKeyWhen if c.childrenResolved && !c.resolved =>
-        val maybeCommonType =
-          findWiderCommonType((c.key +: c.whenList).map(_.dataType))
-        maybeCommonType.map { commonType =>
-          val castedBranches = c.branches.grouped(2).map {
-            case Seq(whenExpr, thenExpr) if whenExpr.dataType != commonType =>
-              Seq(Cast(whenExpr, commonType), thenExpr)
-            case other => other
-          }.reduce(_ ++ _)
-          CaseKeyWhen(Cast(c.key, commonType), castedBranches)
+          CaseWhen(castedBranches)
         }.getOrElse(c)
     }
   }
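
For reference, the widening this rule relies on can be modeled standalone roughly as follows. This is a hypothetical simplified sketch (Expr, Lit, wider, and the two DataType objects are stand-ins, not the Catalyst API): find the widest common type across the THEN/ELSE values, then wrap any value of a different type in a Cast.

    sealed trait DataType
    case object IntType extends DataType
    case object DoubleType extends DataType

    sealed trait Expr { def dataType: DataType }
    case class Lit(dataType: DataType) extends Expr
    case class Cast(child: Expr, dataType: DataType) extends Expr

    // Stand-in for Catalyst's type promotion rules; here only Int widens to Double.
    def wider(a: DataType, b: DataType): Option[DataType] = (a, b) match {
      case _ if a == b => Some(a)
      case (IntType, DoubleType) | (DoubleType, IntType) => Some(DoubleType)
      case _ => None
    }

    // Fold pairwise widening across all types; None if no common type exists.
    def findWiderCommonType(types: Seq[DataType]): Option[DataType] = types match {
      case Seq() => None
      case head +: tail => tail.foldLeft(Option(head)) {
        case (Some(acc), t) => wider(acc, t)
        case (None, _)      => None
      }
    }

    // Cast every value whose type differs from the common type.
    def coerceValues(values: Seq[Expr]): Option[Seq[Expr]] =
      findWiderCommonType(values.map(_.dataType)).map { common =>
        values.map(v => if (v.dataType == common) v else Cast(v, common))
      }

    // coerceValues(Seq(Lit(IntType), Lit(DoubleType)))
    //   == Some(Seq(Cast(Lit(IntType), DoubleType), Lit(DoubleType)))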

http://git-wip-us.apache.org/repos/asf/spark/blob/0ed430e3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 379e62a..5a14624 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.{sequenceOption, TypeUtils}
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 
@@ -78,17 +78,23 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
   override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))"
 }
 
-trait CaseWhenLike extends Expression {
+/**
+ * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
+ * When a = true, returns b; when c = true, returns d; else returns e.
+ */
+case class CaseWhen(branches: Seq[Expression]) extends Expression {
+
+  // Use private[this] Array to speed up evaluation.
+  @transient private[this] lazy val branchesArr = branches.toArray
 
-  // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
-  // element is the value for the default catch-all case (if provided).
-  // Hence, `branches` consists of at least two elements, and can have an odd or even length.
-  def branches: Seq[Expression]
+  override def children: Seq[Expression] = branches
 
   @transient lazy val whenList =
     branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq
+
   @transient lazy val thenList =
     branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq
+
   val elseValue = if (branches.length % 2 == 0) None else Option(branches.last)
 
   // both then and else expressions should be considered.
@@ -97,47 +103,26 @@ trait CaseWhenLike extends Expression {
     case Seq(dt1, dt2) => dt1.sameType(dt2)
   }
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (valueTypesEqual) {
-      checkTypesInternal()
-    } else {
-      TypeCheckResult.TypeCheckFailure(
-        "THEN and ELSE expressions should all be same type or coercible to a common type")
-    }
-  }
-
-  protected def checkTypesInternal(): TypeCheckResult
-
   override def dataType: DataType = thenList.head.dataType
 
   override def nullable: Boolean = {
     // If no value is nullable and no elseValue is provided, the whole statement defaults to null.
     thenList.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true)
   }
-}
-
-// scalastyle:off
-/**
- * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
- * Refer to this link for the corresponding semantics:
- * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
- */
-// scalastyle:on
-case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
-
-  // Use private[this] Array to speed up evaluation.
-  @transient private[this] lazy val branchesArr = branches.toArray
 
-  override def children: Seq[Expression] = branches
-
-  override protected def checkTypesInternal(): TypeCheckResult = {
-    if (whenList.forall(_.dataType == BooleanType)) {
-      TypeCheckResult.TypeCheckSuccess
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (valueTypesEqual) {
+      if (whenList.forall(_.dataType == BooleanType)) {
+        TypeCheckResult.TypeCheckSuccess
+      } else {
+        val index = whenList.indexWhere(_.dataType != BooleanType)
+        TypeCheckResult.TypeCheckFailure(
+          s"WHEN expressions in CaseWhen should all be boolean type, " +
+            s"but the ${index + 1}th when expression's type is ${whenList(index)}")
+      }
     } else {
-      val index = whenList.indexWhere(_.dataType != BooleanType)
       TypeCheckResult.TypeCheckFailure(
-        s"WHEN expressions in CaseWhen should all be boolean type, " +
-          s"but the ${index + 1}th when expression's type is ${whenList(index)}")
+        "THEN and ELSE expressions should all be same type or coercible to a common type")
     }
   }
 
@@ -227,125 +212,23 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
   }
 }
 
-// scalastyle:off
 /**
  * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
- * Refer to this link for the corresponding semantics:
- * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
+ * When a = b, returns c; when a = d, returns e; else returns f.
  */
-// scalastyle:on
-case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike {
-
-  // Use private[this] Array to speed up evaluation.
-  @transient private[this] lazy val branchesArr = branches.toArray
-
-  override def children: Seq[Expression] = key +: branches
-
-  override protected def checkTypesInternal(): TypeCheckResult = {
-    if ((key +: whenList).map(_.dataType).distinct.size > 1) {
-      TypeCheckResult.TypeCheckFailure(
-        "key and WHEN expressions should all be same type or coercible to a common type")
-    } else {
-      TypeCheckResult.TypeCheckSuccess
-    }
-  }
-
-  private def evalElse(input: InternalRow): Any = {
-    if (branchesArr.length % 2 == 0) {
-      null
-    } else {
-      branchesArr(branchesArr.length - 1).eval(input)
-    }
-  }
-
-  /** Written in imperative fashion for performance considerations. */
-  override def eval(input: InternalRow): Any = {
-    val evaluatedKey = key.eval(input)
-    // If key is null, we can just return the else part or null if there is no else.
-    // If key is not null but doesn't match any when part, we need to return
-    // the else part or null if there is no else, according to Hive's semantics.
-    if (evaluatedKey != null) {
-      val len = branchesArr.length
-      var i = 0
-      while (i < len - 1) {
-        if (evaluatedKey ==  branchesArr(i).eval(input)) {
-          return branchesArr(i + 1).eval(input)
-        }
-        i += 2
-      }
-    }
-    evalElse(input)
-  }
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val keyEval = key.gen(ctx)
-    val len = branchesArr.length
-    val got = ctx.freshName("got")
-
-    val cases = (0 until len/2).map { i =>
-      val cond = branchesArr(i * 2).gen(ctx)
-      val res = branchesArr(i * 2 + 1).gen(ctx)
-      s"""
-        if (!$got) {
-          ${cond.code}
-          if (!${cond.isNull} && ${ctx.genEqual(key.dataType, keyEval.value, cond.value)}) {
-            $got = true;
-            ${res.code}
-            ${ev.isNull} = ${res.isNull};
-            ${ev.value} = ${res.value};
-          }
-        }
-      """
-    }.mkString("\n")
-
-    val other = if (len % 2 == 1) {
-      val res = branchesArr(len - 1).gen(ctx)
-      s"""
-        if (!$got) {
-          ${res.code}
-          ${ev.isNull} = ${res.isNull};
-          ${ev.value} = ${res.value};
-        }
-      """
-    } else {
-      ""
-    }
-
-    s"""
-      boolean $got = false;
-      boolean ${ev.isNull} = true;
-      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-      ${keyEval.code}
-      if (!${keyEval.isNull}) {
-        $cases
+object CaseKeyWhen {
+  def apply(key: Expression, branches: Seq[Expression]): CaseWhen = {
+    val newBranches = branches.zipWithIndex.map { case (expr, i) =>
+      if (i % 2 == 0 && i != branches.size - 1) {
+        // If this expression is at even position, then it is either a branch condition, or
+        // the very last value that is the "else value". The "i != branches.size - 1" makes
+        // sure we are not adding an EqualTo to the "else value".
+        EqualTo(key, expr)
+      } else {
+        expr
       }
-      $other
-    """
-  }
-
-  override def toString: String = {
-    s"CASE $key" + branches.sliding(2, 2).map {
-      case Seq(cond, value) => s" WHEN $cond THEN $value"
-      case Seq(elseValue) => s" ELSE $elseValue"
-    }.mkString
-  }
-
-  override def sql: String = {
-    val keySQL = key.sql
-    val branchesSQL = branches.map(_.sql)
-    val (cases, maybeElse) = if (branches.length % 2 == 0) {
-      (branchesSQL, None)
-    } else {
-      (branchesSQL.init, Some(branchesSQL.last))
     }
-
-    val head = s"CASE $keySQL "
-    val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END"
-    val body = cases.grouped(2).map {
-      case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr"
-    }.mkString(" ")
-
-    head + body + tail
+    CaseWhen(newBranches)
   }
 }
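
To see the index arithmetic of the factory above in isolation: even positions hold WHEN keys, except a trailing element in an odd-length sequence, which is the "else value". A standalone sketch (Expr, Lit, and Eq are simplified stand-ins, not the Catalyst classes):

    sealed trait Expr
    case class Lit(v: Any) extends Expr
    case class Eq(left: Expr, right: Expr) extends Expr

    // Mirrors CaseKeyWhen.apply: wrap every even-position element in an equality
    // test against the key, but leave a trailing "else value" untouched.
    def rewrite(key: Expr, branches: Seq[Expr]): Seq[Expr] =
      branches.zipWithIndex.map { case (expr, i) =>
        if (i % 2 == 0 && i != branches.size - 1) Eq(key, expr) else expr
      }

    // Odd length: the last element is the else value and stays unwrapped.
    // rewrite(Lit("k"), Seq(Lit(1), Lit("a"), Lit("b")))
    //   == Seq(Eq(Lit("k"), Lit(1)), Lit("a"), Lit("b"))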
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0ed430e3/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 58d808c..23b11af 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -299,7 +299,7 @@ class HiveTypeCoercionSuite extends PlanTest {
   }
 
   test("type coercion for CaseKeyWhen") {
-    ruleTest(HiveTypeCoercion.CaseWhenCoercion,
+    ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
       CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
       CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
     )

