Posted to commits@spark.apache.org by rx...@apache.org on 2016/01/13 21:44:43 UTC

spark git commit: [SPARK-12791][SQL] Simplify CaseWhen by breaking "branches" into "conditions" and "values"

Repository: spark
Updated Branches:
  refs/heads/master c2ea79f96 -> cbbcd8e42


[SPARK-12791][SQL] Simplify CaseWhen by breaking "branches" into "conditions" and "values"

This pull request rewrites the CaseWhen expression to break the single, monolithic "branches" field into a sequence of (condition, value) tuples, i.e. Seq[(condition, value)], and an explicit optional elseValue field.

Prior to this pull request, each even position in "branches" held a branch condition, and each odd position held the corresponding branch value. Working with that layout was pretty confusing and required a lot of sliding-window or grouped(2) calls.
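
For illustration, here is the shape change in Scala (cond1, value1, etc. are placeholder expressions, not from the patch):

    // Before: a flat Seq[Expression]; conditions at even positions, values
    // at odd positions, with an optional trailing ELSE value.
    CaseWhen(Seq(cond1, value1, cond2, value2, elseValue))

    // After: explicit (condition, value) pairs plus an Option for ELSE.
    CaseWhen(Seq((cond1, value1), (cond2, value2)), Some(elseValue))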

Author: Reynold Xin <rx...@databricks.com>

Closes #10734 from rxin/simplify-case.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cbbcd8e4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cbbcd8e4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cbbcd8e4

Branch: refs/heads/master
Commit: cbbcd8e4250aeec700f04c231f8be2f787243f1f
Parents: c2ea79f
Author: Reynold Xin <rx...@databricks.com>
Authored: Wed Jan 13 12:44:35 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Wed Jan 13 12:44:35 2016 -0800

----------------------------------------------------------------------
 python/pyspark/sql/column.py                    |  24 ++--
 .../apache/spark/sql/catalyst/CatalystQl.scala  |   2 +-
 .../apache/spark/sql/catalyst/SqlParser.scala   |   3 +-
 .../catalyst/analysis/HiveTypeCoercion.scala    |  26 ++--
 .../expressions/conditionalExpressions.scala    | 137 +++++++++----------
 .../spark/sql/catalyst/trees/TreeNode.scala     |   9 ++
 .../sql/catalyst/analysis/AnalysisSuite.scala   |   2 +-
 .../analysis/ExpressionTypeCheckingSuite.scala  |   4 +-
 .../analysis/HiveTypeCoercionSuite.scala        |  15 +-
 .../ConditionalExpressionSuite.scala            |  51 +++----
 .../scala/org/apache/spark/sql/Column.scala     |  19 +--
 .../scala/org/apache/spark/sql/functions.scala  |   2 +-
 12 files changed, 156 insertions(+), 138 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cbbcd8e4/python/pyspark/sql/column.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 900def5..320451c 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -368,12 +368,12 @@ class Column(object):
 
         >>> from pyspark.sql import functions as F
         >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
-        +-----+--------------------------------------------------------+
-        | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0|
-        +-----+--------------------------------------------------------+
-        |Alice|                                                      -1|
-        |  Bob|                                                       1|
-        +-----+--------------------------------------------------------+
+        +-----+------------------------------------------------------------+
+        | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END|
+        +-----+------------------------------------------------------------+
+        |Alice|                                                          -1|
+        |  Bob|                                                           1|
+        +-----+------------------------------------------------------------+
         """
         if not isinstance(condition, Column):
             raise TypeError("condition should be a Column")
@@ -393,12 +393,12 @@ class Column(object):
 
         >>> from pyspark.sql import functions as F
         >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show()
-        +-----+---------------------------------+
-        | name|CASE WHEN (age > 3) THEN 1 ELSE 0|
-        +-----+---------------------------------+
-        |Alice|                                0|
-        |  Bob|                                1|
-        +-----+---------------------------------+
+        +-----+-------------------------------------+
+        | name|CASE WHEN (age > 3) THEN 1 ELSE 0 END|
+        +-----+-------------------------------------+
+        |Alice|                                    0|
+        |  Bob|                                    1|
+        +-----+-------------------------------------+
         """
         v = value._jc if isinstance(value, Column) else value
         jc = self._jc.otherwise(v)
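
For comparison, a minimal Scala sketch of the same chain through the DataFrame API (assuming a df with name and age columns, as in the doctest); the generated column name now ends in "END" here as well:

    import org.apache.spark.sql.functions._

    // CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END
    df.select(df("name"),
      when(df("age") > 4, 1).when(df("age") < 3, -1).otherwise(0)).show()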

http://git-wip-us.apache.org/repos/asf/spark/blob/cbbcd8e4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
index c87b6c8..d0fbdac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
@@ -752,7 +752,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
 
     /* Case statements */
     case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
-      CaseWhen(branches.map(nodeToExpr))
+      CaseWhen.createFromParser(branches.map(nodeToExpr))
     case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
       val keyExpr = nodeToExpr(branches.head)
       CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr))
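
The parser still hands over a flat list of expressions, so createFromParser (added later in this diff) re-pairs them. A sketch with placeholder expressions c1, v1, c2, v2, e:

    // Parser output for CASE WHEN c1 THEN v1 WHEN c2 THEN v2 ELSE e END:
    val flat = Seq(c1, v1, c2, v2, e)

    // Odd length means the last element is the ELSE value, so this yields
    // CaseWhen(Seq((c1, v1), (c2, v2)), Some(e)).
    CaseWhen.createFromParser(flat)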

http://git-wip-us.apache.org/repos/asf/spark/blob/cbbcd8e4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 6ec408a..85ff4ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -305,7 +305,8 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
           throw new AnalysisException(s"invalid function approximate($s) $udfName")
         }
       }
-    | CASE ~> whenThenElse ^^ CaseWhen
+    | CASE ~> whenThenElse ^^
+      { case branches => CaseWhen.createFromParser(branches) }
     | CASE ~> expression ~ whenThenElse ^^
       { case keyPart ~ branches => CaseKeyWhen(keyPart, branches) }
     )

http://git-wip-us.apache.org/repos/asf/spark/blob/cbbcd8e4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 980b5d5..2737fe3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -621,14 +621,24 @@ object HiveTypeCoercion {
       case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual =>
         val maybeCommonType = findWiderCommonType(c.valueTypes)
         maybeCommonType.map { commonType =>
-          val castedBranches = c.branches.grouped(2).map {
-            case Seq(when, value) if value.dataType != commonType =>
-              Seq(when, Cast(value, commonType))
-            case Seq(elseVal) if elseVal.dataType != commonType =>
-              Seq(Cast(elseVal, commonType))
-            case other => other
-          }.reduce(_ ++ _)
-          CaseWhen(castedBranches)
+          var changed = false
+          val newBranches = c.branches.map { case (condition, value) =>
+            if (value.dataType.sameType(commonType)) {
+              (condition, value)
+            } else {
+              changed = true
+              (condition, Cast(value, commonType))
+            }
+          }
+          val newElseValue = c.elseValue.map { value =>
+            if (value.dataType.sameType(commonType)) {
+              value
+            } else {
+              changed = true
+              Cast(value, commonType)
+            }
+          }
+          if (changed) CaseWhen(newBranches, newElseValue) else c
         }.getOrElse(c)
     }
   }
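
For example (taken from the updated HiveTypeCoercionSuite below), the rule now casts only the values whose type differs from the widest common type, and returns the original expression untouched when nothing changed:

    // Branch value is DoubleType, ELSE value is DecimalType(7, 2); the
    // wider common type is DoubleType, so only the ELSE value is cast.
    CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2)))
    // becomes:
    CaseWhen(Seq((Literal(true), Literal(1.2))),
      Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))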

http://git-wip-us.apache.org/repos/asf/spark/blob/cbbcd8e4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 5a14624..8cc7bc1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -81,44 +81,39 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
 /**
  * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
  * When a = true, returns b; when c = true, returns d; else returns e.
+ *
+ * @param branches seq of (branch condition, branch value)
+ * @param elseValue optional value for the else branch
  */
-case class CaseWhen(branches: Seq[Expression]) extends Expression {
-
-  // Use private[this] Array to speed up evaluation.
-  @transient private[this] lazy val branchesArr = branches.toArray
-
-  override def children: Seq[Expression] = branches
-
-  @transient lazy val whenList =
-    branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq
-
-  @transient lazy val thenList =
-    branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq
+case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None)
+  extends Expression {
 
-  val elseValue = if (branches.length % 2 == 0) None else Option(branches.last)
+  override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
 
   // both then and else expressions should be considered.
-  def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
+  def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType)
+
   def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall {
     case Seq(dt1, dt2) => dt1.sameType(dt2)
   }
 
-  override def dataType: DataType = thenList.head.dataType
+  override def dataType: DataType = branches.head._2.dataType
 
   override def nullable: Boolean = {
-    // If no value is nullable and no elseValue is provided, the whole statement defaults to null.
-    thenList.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true)
+    // Result is nullable if any of the branch values is nullable, or if the else value is nullable.
+    branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true)
   }
 
   override def checkInputDataTypes(): TypeCheckResult = {
+    // Make sure all branch conditions are boolean types.
     if (valueTypesEqual) {
-      if (whenList.forall(_.dataType == BooleanType)) {
+      if (branches.forall(_._1.dataType == BooleanType)) {
         TypeCheckResult.TypeCheckSuccess
       } else {
-        val index = whenList.indexWhere(_.dataType != BooleanType)
+        val index = branches.indexWhere(_._1.dataType != BooleanType)
         TypeCheckResult.TypeCheckFailure(
           s"WHEN expressions in CaseWhen should all be boolean type, " +
-            s"but the ${index + 1}th when expression's type is ${whenList(index)}")
+            s"but the ${index + 1}th when expression's type is ${branches(index)._1}")
       }
     } else {
       TypeCheckResult.TypeCheckFailure(
@@ -127,31 +122,26 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
   }
 
   override def eval(input: InternalRow): Any = {
-    // Written in imperative fashion for performance considerations
-    val len = branchesArr.length
     var i = 0
-    // If all branches fail and an elseVal is not provided, the whole statement
-    // defaults to null, according to Hive's semantics.
-    while (i < len - 1) {
-      if (branchesArr(i).eval(input) == true) {
-        return branchesArr(i + 1).eval(input)
+    while (i < branches.size) {
+      if (java.lang.Boolean.TRUE.equals(branches(i)._1.eval(input))) {
+        return branches(i)._2.eval(input)
       }
-      i += 2
+      i += 1
     }
-    var res: Any = null
-    if (i == len - 1) {
-      res = branchesArr(i).eval(input)
+    if (elseValue.isDefined) {
+      return elseValue.get.eval(input)
+    } else {
+      return null
     }
-    return res
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val len = branchesArr.length
     val got = ctx.freshName("got")
 
-    val cases = (0 until len/2).map { i =>
-      val cond = branchesArr(i * 2).gen(ctx)
-      val res = branchesArr(i * 2 + 1).gen(ctx)
+    val cases = branches.map { case (condition, value) =>
+      val cond = condition.gen(ctx)
+      val res = value.gen(ctx)
       s"""
         if (!$got) {
           ${cond.code}
@@ -165,17 +155,19 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
       """
     }.mkString("\n")
 
-    val other = if (len % 2 == 1) {
-      val res = branchesArr(len - 1).gen(ctx)
-      s"""
+    val elseCase = {
+      if (elseValue.isDefined) {
+        val res = elseValue.get.gen(ctx)
+        s"""
         if (!$got) {
           ${res.code}
           ${ev.isNull} = ${res.isNull};
           ${ev.value} = ${res.value};
         }
-      """
-    } else {
-      ""
+        """
+      } else {
+        ""
+      }
     }
 
     s"""
@@ -183,32 +175,42 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
       boolean ${ev.isNull} = true;
       ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
       $cases
-      $other
+      $elseCase
     """
   }
 
   override def toString: String = {
-    "CASE" + branches.sliding(2, 2).map {
-      case Seq(cond, value) => s" WHEN $cond THEN $value"
-      case Seq(elseValue) => s" ELSE $elseValue"
-    }.mkString
+    val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
+    val elseCase = elseValue.map(" ELSE " + _).getOrElse("")
+    "CASE" + cases + elseCase + " END"
   }
 
   override def sql: String = {
-    val branchesSQL = branches.map(_.sql)
-    val (cases, maybeElse) = if (branches.length % 2 == 0) {
-      (branchesSQL, None)
-    } else {
-      (branchesSQL.init, Some(branchesSQL.last))
-    }
+    val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString
+    val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
+    "CASE" + cases + elseCase + " END"
+  }
+}
 
-    val head = s"CASE "
-    val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END"
-    val body = cases.grouped(2).map {
-      case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr"
-    }.mkString(" ")
+/** Factory methods for CaseWhen. */
+object CaseWhen {
 
-    head + body + tail
+  def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = {
+    CaseWhen(branches, Option(elseValue))
+  }
+
+  /**
+   * A factory method to facilitate the creation of this expression when used in parsers.
+   * @param branches Expressions at even position are the branch conditions, and expressions at odd
+   *                 position are branch values.
+   */
+  def createFromParser(branches: Seq[Expression]): CaseWhen = {
+    val cases = branches.grouped(2).flatMap {
+      case cond :: value :: Nil => Some((cond, value))
+      case value :: Nil => None
+    }.toArray.toSeq  // force materialization to make the seq serializable
+    val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None
+    CaseWhen(cases, elseValue)
   }
 }
 
@@ -218,17 +220,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
  */
 object CaseKeyWhen {
   def apply(key: Expression, branches: Seq[Expression]): CaseWhen = {
-    val newBranches = branches.zipWithIndex.map { case (expr, i) =>
-      if (i % 2 == 0 && i != branches.size - 1) {
-        // If this expression is at even position, then it is either a branch condition, or
-        // the very last value that is the "else value". The "i != branches.size - 1" makes
-        // sure we are not adding an EqualTo to the "else value".
-        EqualTo(key, expr)
-      } else {
-        expr
-      }
-    }
-    CaseWhen(newBranches)
+    val cases = branches.grouped(2).flatMap {
+      case cond :: value :: Nil => Some((EqualTo(key, cond), value))
+      case value :: Nil => None
+    }.toArray.toSeq  // force materialization to make the seq serializable
+    val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None
+    CaseWhen(cases, elseValue)
   }
 }
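
CaseKeyWhen is now a pure factory: CASE key WHEN ... forms desugar directly into a CaseWhen with EqualTo conditions. A sketch with placeholder expressions key, a, b, c, d, e:

    // CASE key WHEN a THEN b WHEN c THEN d ELSE e END
    CaseKeyWhen(key, Seq(a, b, c, d, e))
    // yields:
    CaseWhen(Seq((EqualTo(key, a), b), (EqualTo(key, c), d)), Some(e))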
 

http://git-wip-us.apache.org/repos/asf/spark/blob/cbbcd8e4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index d4be545..d0b29aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -315,6 +315,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
           } else {
             arg
           }
+        case tuple @ (arg1: TreeNode[_], arg2: TreeNode[_]) =>
+          val newChild1 = nextOperation(arg1.asInstanceOf[BaseType], rule)
+          val newChild2 = nextOperation(arg2.asInstanceOf[BaseType], rule)
+          if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
+            changed = true
+            (newChild1, newChild2)
+          } else {
+            tuple
+          }
         case other => other
       }
       case nonChild: AnyRef => nonChild
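
This TreeNode addition is what lets transformation rules reach the new pair-typed field: the productIterator-based traversal now descends into two-element tuples of TreeNodes, so rules like the coercion rule above can rewrite expressions stored inside branches. A hedged sketch (a and b are assumed to be resolved attributes):

    val cw = CaseWhen(Seq((EqualTo(a, Literal(1)), b)), None)

    // Without the tuple case, this rule would never see the Literal(1)
    // nested inside the (condition, value) pair.
    cw transformUp {
      case Literal(1, IntegerType) => Literal(2)
    }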

http://git-wip-us.apache.org/repos/asf/spark/blob/cbbcd8e4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index cf84855..975cd87 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -239,7 +239,7 @@ class AnalysisSuite extends AnalysisTest {
 
   test("SPARK-12102: Ignore nullablity when comparing two sides of case") {
     val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false)))
-    val plan = relation.select(CaseWhen(Seq(Literal(true), 'a, 'b)).as("val"))
+    val plan = relation.select(CaseWhen(Seq((Literal(true), 'a.attr)), 'b).as("val"))
     assertAnalysisSuccess(plan)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/cbbcd8e4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 0521ed8..59549e3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -132,13 +132,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField))
 
     assertError(
-      CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'mapField)),
+      CaseWhen(Seq(('booleanField.attr, 'intField.attr), ('booleanField.attr, 'mapField.attr))),
       "THEN and ELSE expressions should all be same type or coercible to a common type")
     assertError(
       CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'mapField)),
       "THEN and ELSE expressions should all be same type or coercible to a common type")
     assertError(
-      CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)),
+      CaseWhen(Seq(('booleanField.attr, 'intField.attr), ('intField.attr, 'intField.attr))),
       "WHEN expressions in CaseWhen should all be boolean type")
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/cbbcd8e4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 40378c6..b1f6c0b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -308,15 +308,14 @@ class HiveTypeCoercionSuite extends PlanTest {
       CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
     )
     ruleTest(HiveTypeCoercion.CaseWhenCoercion,
-      CaseWhen(Seq(Literal(true), Literal(1.2), Literal.create(1, DecimalType(7, 2)))),
-      CaseWhen(Seq(
-        Literal(true), Literal(1.2), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType)))
+      CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
+      CaseWhen(Seq((Literal(true), Literal(1.2))),
+        Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
     )
     ruleTest(HiveTypeCoercion.CaseWhenCoercion,
-      CaseWhen(Seq(Literal(true), Literal(100L), Literal.create(1, DecimalType(7, 2)))),
-      CaseWhen(Seq(
-        Literal(true), Cast(Literal(100L), DecimalType(22, 2)),
-        Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2))))
+      CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
+      CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
+        Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
     )
   }
 
@@ -452,7 +451,7 @@ class HiveTypeCoercionSuite extends PlanTest {
     val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5),
       DecimalType(25, 5), DoubleType, DoubleType)
 
-    rightTypes.zip(expectedTypes).map { case (rType, expectedType) =>
+    rightTypes.zip(expectedTypes).foreach { case (rType, expectedType) =>
       val plan2 = LocalRelation(
         AttributeReference("r", rType)())
 

http://git-wip-us.apache.org/repos/asf/spark/blob/cbbcd8e4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index 4029da5..3c581ec 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -80,38 +80,39 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     val c5 = 'a.string.at(4)
     val c6 = 'a.string.at(5)
 
-    checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row)
-    checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row)
-    checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row)
-    checkEvaluation(CaseWhen(Seq(Literal.create(null, BooleanType), c4, c6)), "c", row)
-    checkEvaluation(CaseWhen(Seq(Literal.create(false, BooleanType), c4, c6)), "c", row)
-    checkEvaluation(CaseWhen(Seq(Literal.create(true, BooleanType), c4, c6)), "a", row)
-
-    checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row)
-    checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row)
-    checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row)
-    checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row)
-
-    assert(CaseWhen(Seq(c2, c4, c6)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true)
+    checkEvaluation(CaseWhen(Seq((c1, c4)), c6), "c", row)
+    checkEvaluation(CaseWhen(Seq((c2, c4)), c6), "c", row)
+    checkEvaluation(CaseWhen(Seq((c3, c4)), c6), "a", row)
+    checkEvaluation(CaseWhen(Seq((Literal.create(null, BooleanType), c4)), c6), "c", row)
+    checkEvaluation(CaseWhen(Seq((Literal.create(false, BooleanType), c4)), c6), "c", row)
+    checkEvaluation(CaseWhen(Seq((Literal.create(true, BooleanType), c4)), c6), "a", row)
+
+    checkEvaluation(CaseWhen(Seq((c3, c4), (c2, c5)), c6), "a", row)
+    checkEvaluation(CaseWhen(Seq((c2, c4), (c3, c5)), c6), "b", row)
+    checkEvaluation(CaseWhen(Seq((c1, c4), (c2, c5)), c6), "c", row)
+    checkEvaluation(CaseWhen(Seq((c1, c4), (c2, c5))), null, row)
+
+    assert(CaseWhen(Seq((c2, c4)), c6).nullable === true)
+    assert(CaseWhen(Seq((c2, c4), (c3, c5)), c6).nullable === true)
+    assert(CaseWhen(Seq((c2, c4), (c3, c5))).nullable === true)
 
     val c4_notNull = 'a.boolean.notNull.at(3)
     val c5_notNull = 'a.boolean.notNull.at(4)
     val c6_notNull = 'a.boolean.notNull.at(5)
 
-    assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false)
-    assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull)), c6_notNull).nullable === false)
+    assert(CaseWhen(Seq((c2, c4)), c6_notNull).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull))).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull)), c6).nullable === true)
 
-    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false)
-    assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull)), c6_notNull).nullable === false)
+    assert(CaseWhen(Seq((c2, c4), (c3, c5_notNull)), c6_notNull).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5)), c6_notNull).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull)), c6).nullable === true)
 
-    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull))).nullable === true)
+    assert(CaseWhen(Seq((c2, c4), (c3, c5_notNull))).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5))).nullable === true)
   }
 
   test("case key when") {

http://git-wip-us.apache.org/repos/asf/spark/blob/cbbcd8e4/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index e8c61d6..6a020f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -437,8 +437,11 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    * @since 1.4.0
    */
   def when(condition: Column, value: Any): Column = this.expr match {
-    case CaseWhen(branches: Seq[Expression]) =>
-      withExpr { CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr)) }
+    case CaseWhen(branches, None) =>
+      withExpr { CaseWhen(branches :+ (condition.expr, lit(value).expr)) }
+    case CaseWhen(branches, Some(_)) =>
+      throw new IllegalArgumentException(
+        "when() cannot be applied once otherwise() is applied")
     case _ =>
       throw new IllegalArgumentException(
         "when() can only be applied on a Column previously generated by when() function")
@@ -466,13 +469,11 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    * @since 1.4.0
    */
   def otherwise(value: Any): Column = this.expr match {
-    case CaseWhen(branches: Seq[Expression]) =>
-      if (branches.size % 2 == 0) {
-        withExpr { CaseWhen(branches :+ lit(value).expr) }
-      } else {
-        throw new IllegalArgumentException(
-          "otherwise() can only be applied once on a Column previously generated by when()")
-      }
+    case CaseWhen(branches, None) =>
+      withExpr { CaseWhen(branches, Option(lit(value).expr)) }
+    case CaseWhen(branches, Some(_)) =>
+      throw new IllegalArgumentException(
+        "otherwise() can only be applied once on a Column previously generated by when()")
     case _ =>
       throw new IllegalArgumentException(
         "otherwise() can only be applied on a Column previously generated by when()")

http://git-wip-us.apache.org/repos/asf/spark/blob/cbbcd8e4/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 71fea27..b8ea226 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1042,7 +1042,7 @@ object functions extends LegacyFunctions {
    * @since 1.4.0
    */
   def when(condition: Column, value: Any): Column = withExpr {
-    CaseWhen(Seq(condition.expr, lit(value).expr))
+    CaseWhen(Seq((condition.expr, lit(value).expr)))
   }
 
   /**


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org