You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2015/05/08 01:26:58 UTC

spark git commit: [SPARK-2155] [SQL] add CaseKeyWhen for "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END"

Repository: spark
Updated Branches:
  refs/heads/master 937ba798c -> 35f0173b8


[SPARK-2155] [SQL] add CaseKeyWhen for "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END"

Avoid translating to CaseWhen, which would evaluate the key expression many times; CaseKeyWhen evaluates the key only once per row.

Author: Wenchen Fan <cl...@outlook.com>

Closes #5979 from cloud-fan/condition and squashes the following commits:

3ce54e1 [Wenchen Fan] add CaseKeyWhen


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/35f0173b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/35f0173b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/35f0173b

Branch: refs/heads/master
Commit: 35f0173b8f67e2e506fc4575be6430cfb66e2238
Parents: 937ba79
Author: Wenchen Fan <cl...@outlook.com>
Authored: Thu May 7 16:26:49 2015 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Thu May 7 16:26:49 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/SqlParser.scala   |  10 +-
 .../catalyst/analysis/HiveTypeCoercion.scala    |  43 +++---
 .../sql/catalyst/expressions/Expression.scala   |   2 +-
 .../sql/catalyst/expressions/predicates.scala   | 135 +++++++++++++------
 .../expressions/ExpressionEvaluationSuite.scala |  26 ++++
 .../apache/spark/sql/DataFrameNaFunctions.scala |   9 +-
 .../org/apache/spark/sql/hive/HiveQl.scala      |  12 +-
 .../sql/hive/execution/SQLQuerySuite.scala      |   7 +
 8 files changed, 159 insertions(+), 85 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/35f0173b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 1d3a2dc..b06bfb2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -296,13 +296,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
     | LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) }
     | IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^
       { case c ~ t ~ f => If(c, t, f) }
-    | CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~
+    | CASE ~> expression.? ~ rep1(WHEN ~> expression ~ (THEN ~> expression)) ~
         (ELSE ~> expression).? <~ END ^^ {
           case casePart ~ altPart ~ elsePart =>
-            val altExprs = altPart.flatMap { case whenExpr ~ thenExpr =>
-              Seq(casePart.fold(whenExpr)(EqualTo(_, whenExpr)), thenExpr)
-            }
-            CaseWhen(altExprs ++ elsePart.toList)
+            val branches = altPart.flatMap { case whenExpr ~ thenExpr =>
+              Seq(whenExpr, thenExpr)
+            } ++ elsePart
+            casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches))
         }
     | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^
       { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) }

http://git-wip-us.apache.org/repos/asf/spark/blob/35f0173b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 873c75c..168a4e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -631,31 +631,24 @@ trait HiveTypeCoercion {
     import HiveTypeCoercion._
 
     def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-      case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved)  =>
-        val valueTypes = branches.sliding(2, 2).map {
-          case Seq(_, value) => value.dataType
-          case Seq(elseVal) => elseVal.dataType
-        }.toSeq
-
-        logDebug(s"Input values for null casting ${valueTypes.mkString(",")}")
-
-        if (valueTypes.distinct.size > 1) {
-          val commonType = valueTypes.reduce { (v1, v2) =>
-            findTightestCommonType(v1, v2)
-              .getOrElse(sys.error(
-                s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
-          }
-          val transformedBranches = branches.sliding(2, 2).map {
-            case Seq(cond, value) if value.dataType != commonType =>
-              Seq(cond, Cast(value, commonType))
-            case Seq(elseVal) if elseVal.dataType != commonType =>
-              Seq(Cast(elseVal, commonType))
-            case s => s
-          }.reduce(_ ++ _)
-          CaseWhen(transformedBranches)
-        } else {
-          // Types match up.  Hopefully some other rule fixes whatever is wrong with resolution.
-          cw
+      case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual  =>
+        logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
+        val commonType = cw.valueTypes.reduce { (v1, v2) =>
+          findTightestCommonType(v1, v2).getOrElse(sys.error(
+            s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
+        }
+        val transformedBranches = cw.branches.sliding(2, 2).map {
+          case Seq(when, value) if value.dataType != commonType =>
+            Seq(when, Cast(value, commonType))
+          case Seq(elseVal) if elseVal.dataType != commonType =>
+            Seq(Cast(elseVal, commonType))
+          case s => s
+        }.reduce(_ ++ _)
+        cw match {
+          case _: CaseWhen =>
+            CaseWhen(transformedBranches)
+          case CaseKeyWhen(key, _) =>
+            CaseKeyWhen(key, transformedBranches)
         }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/35f0173b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 4fd1bc4..0837a31 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -64,7 +64,7 @@ abstract class Expression extends TreeNode[Expression] {
    * Returns true if  all the children of this expression have been resolved to a specific schema
    * and false if any still contains any unresolved placeholders.
    */
-  def childrenResolved: Boolean = !children.exists(!_.resolved)
+  def childrenResolved: Boolean = children.forall(_.resolved)
 
   /**
    * Returns a string representation of this expression that does not have developer centric

http://git-wip-us.apache.org/repos/asf/spark/blob/35f0173b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 26c38c5..50b0f3e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -353,79 +353,134 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
   override def toString: String = s"if ($predicate) $trueValue else $falseValue"
 }
 
+trait CaseWhenLike extends Expression {
+  self: Product =>
+
+  type EvaluatedType = Any
+
+  // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
+  // element is the value for the default catch-all case (if provided).
+  // Hence, `branches` consists of at least two elements, and can have an odd or even length.
+  def branches: Seq[Expression]
+
+  @transient lazy val whenList =
+    branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq
+  @transient lazy val thenList =
+    branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq
+  val elseValue = if (branches.length % 2 == 0) None else Option(branches.last)
+
+  // both then and else val should be considered.
+  def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
+  def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1
+
+  override def dataType: DataType = {
+    if (!resolved) {
+      throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
+    }
+    valueTypes.head
+  }
+
+  override def nullable: Boolean = {
+    // If no value is nullable and no elseValue is provided, the whole statement defaults to null.
+    thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
+  }
+}
+
 // scalastyle:off
 /**
  * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
  * Refer to this link for the corresponding semantics:
  * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
- *
- * The other form of case statements "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END" gets
- * translated to this form at parsing time.  Namely, such a statement gets translated to
- * "CASE WHEN a=b THEN c [WHEN a=d THEN e]* [ELSE f] END".
- *
- * Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
- * element is the value for the default catch-all case (if provided). Hence, `branches` consists of
- * at least two elements, and can have an odd or even length.
  */
 // scalastyle:on
-case class CaseWhen(branches: Seq[Expression]) extends Expression {
-  type EvaluatedType = Any
+case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
+
+  // Use private[this] Array to speed up evaluation.
+  @transient private[this] lazy val branchesArr = branches.toArray
 
   override def children: Seq[Expression] = branches
 
-  override def dataType: DataType = {
-    if (!resolved) {
-      throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
+  override lazy val resolved: Boolean =
+    childrenResolved &&
+    whenList.forall(_.dataType == BooleanType) &&
+    valueTypesEqual
+
+  /** Written in imperative fashion for performance considerations. */
+  override def eval(input: Row): Any = {
+    val len = branchesArr.length
+    var i = 0
+    // If all branches fail and an elseVal is not provided, the whole statement
+    // defaults to null, according to Hive's semantics.
+    while (i < len - 1) {
+      if (branchesArr(i).eval(input) == true) {
+        return branchesArr(i + 1).eval(input)
+      }
+      i += 2
+    }
+    var res: Any = null
+    if (i == len - 1) {
+      res = branchesArr(i).eval(input)
     }
-    branches(1).dataType
+    return res
   }
 
+  override def toString: String = {
+    "CASE" + branches.sliding(2, 2).map {
+      case Seq(cond, value) => s" WHEN $cond THEN $value"
+      case Seq(elseValue) => s" ELSE $elseValue"
+    }.mkString
+  }
+}
+
+// scalastyle:off
+/**
+ * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
+ * Refer to this link for the corresponding semantics:
+ * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
+ */
+// scalastyle:on
+case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike {
+
+  // Use private[this] Array to speed up evaluation.
   @transient private[this] lazy val branchesArr = branches.toArray
-  @transient private[this] lazy val predicates =
-    branches.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq
-  @transient private[this] lazy val values =
-    branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq
-  @transient private[this] lazy val elseValue =
-    if (branches.length % 2 == 0) None else Option(branches.last)
 
-  override def nullable: Boolean = {
-    // If no value is nullable and no elseValue is provided, the whole statement defaults to null.
-    values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
-  }
+  override def children: Seq[Expression] = key +: branches
 
-  override lazy val resolved: Boolean = {
-    if (!childrenResolved) {
-      false
-    } else {
-      val allCondBooleans = predicates.forall(_.dataType == BooleanType)
-      // both then and else val should be considered.
-      val dataTypesEqual = (values ++ elseValue).map(_.dataType).distinct.size <= 1
-      allCondBooleans && dataTypesEqual
-    }
-  }
+  override lazy val resolved: Boolean =
+    childrenResolved && valueTypesEqual
 
   /** Written in imperative fashion for performance considerations. */
   override def eval(input: Row): Any = {
+    val evaluatedKey = key.eval(input)
     val len = branchesArr.length
     var i = 0
     // If all branches fail and an elseVal is not provided, the whole statement
     // defaults to null, according to Hive's semantics.
-    var res: Any = null
     while (i < len - 1) {
-      if (branchesArr(i).eval(input) == true) {
-        res = branchesArr(i + 1).eval(input)
-        return res
+      if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) {
+        return branchesArr(i + 1).eval(input)
       }
       i += 2
     }
+    var res: Any = null
     if (i == len - 1) {
       res = branchesArr(i).eval(input)
     }
-    res
+    return res
+  }
+
+  private def equalNullSafe(l: Any, r: Any) = {
+    if (l == null && r == null) {
+      true
+    } else if (l == null || r == null) {
+      false
+    } else {
+      l == r
+    }
   }
 
   override def toString: String = {
-    "CASE" + branches.sliding(2, 2).map {
+    s"CASE $key" + branches.sliding(2, 2).map {
       case Seq(cond, value) => s" WHEN $cond THEN $value"
       case Seq(elseValue) => s" ELSE $elseValue"
     }.mkString

http://git-wip-us.apache.org/repos/asf/spark/blob/35f0173b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index faaa55a..88d36d1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -850,6 +850,32 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
     assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true)
   }
 
+  test("case key when") {
+    val row = create_row(null, 1, 2, "a", "b", "c")
+    val c1 = 'a.int.at(0)
+    val c2 = 'a.int.at(1)
+    val c3 = 'a.int.at(2)
+    val c4 = 'a.string.at(3)
+    val c5 = 'a.string.at(4)
+    val c6 = 'a.string.at(5)
+
+    val literalNull = Literal.create(null, BooleanType)
+    val literalInt = Literal(1)
+    val literalString = Literal("a")
+
+    checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, c5)), "b", row)
+    checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, literalNull, c5, c6)), "b", row)
+    checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row)
+    checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row)
+    checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row)
+    checkEvaluation(CaseKeyWhen(c4, Seq(c1, c3, c5, c2, Literal(3))), 3, row)
+
+    checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row)
+    checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row)
+    checkEvaluation(CaseKeyWhen(literalInt, Seq(c5, c2, c4, c3)), null, row)
+    checkEvaluation(CaseKeyWhen(literalNull, Seq(c5, c2, c1, c3)), 2, row)
+  }
+
   test("complex type") {
     val row = create_row(
       "^Ba*n",                                // 0

http://git-wip-us.apache.org/repos/asf/spark/blob/35f0173b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 481ed49..4a54120 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -357,11 +357,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * TODO: This can be optimized to use broadcast join when replacementMap is large.
    */
   private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = {
-    val branches: Seq[Expression] = replacementMap.flatMap { case (source, target) =>
-      df.col(col.name).equalTo(lit(source).cast(col.dataType)).expr ::
-        lit(target).cast(col.dataType).expr :: Nil
+    val keyExpr = df.col(col.name).expr
+    def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
+    val branches = replacementMap.flatMap { case (source, target) =>
+      Seq(buildExpr(source), buildExpr(target))
     }.toSeq
-    new Column(CaseWhen(branches ++ Seq(df.col(col.name).expr))).as(col.name)
+    new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
   }
 
   private def convertToDouble(v: Any): Double = v match {

http://git-wip-us.apache.org/repos/asf/spark/blob/35f0173b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 4e51473..6176aee 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1246,16 +1246,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
     case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
       CaseWhen(branches.map(nodeToExpr))
     case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
-      val transformed = branches.drop(1).sliding(2, 2).map {
-        case Seq(condVal, value) =>
-          // FIXME (SPARK-2155): the key will get evaluated for multiple times in CaseWhen's eval().
-          // Hence effectful / non-deterministic key expressions are *not* supported at the moment.
-          // We should consider adding new Expressions to get around this.
-          Seq(EqualTo(nodeToExpr(branches(0)), nodeToExpr(condVal)),
-              nodeToExpr(value))
-        case Seq(elseVal) => Seq(nodeToExpr(elseVal))
-      }.toSeq.reduce(_ ++ _)
-      CaseWhen(transformed)
+      val keyExpr = nodeToExpr(branches.head)
+      CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr))
 
     /* Complex datatype manipulation */
     case Token("[", child :: ordinal :: Nil) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/35f0173b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 616352d..c605f10 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -751,4 +751,11 @@ class SQLQuerySuite extends QueryTest {
         (6, "c", 0, 6)
       ).map(i => Row(i._1, i._2, i._3, i._4)))
   }
+
+  test("test case key when") {
+    (1 to 5).map(i => (i, i.toString)).toDF("k", "v").registerTempTable("t")
+    checkAnswer(
+      sql("SELECT CASE k WHEN 2 THEN 22 WHEN 4 THEN 44 ELSE 0 END, v FROM t"),
+      Row(0, "1") :: Row(22, "2") :: Row(0, "3") :: Row(44, "4") :: Row(0, "5") :: Nil)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org