You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2014/06/20 09:12:55 UTC

git commit: [SPARK-2196] [SQL] Fix nullability of CaseWhen.

Repository: spark
Updated Branches:
  refs/heads/master f46e02fcd -> 324952892


[SPARK-2196] [SQL] Fix nullability of CaseWhen.

`CaseWhen` should use `branches.length` to check if `elseValue` is provided or not.

Author: Takuya UESHIN <ue...@happy-camper.st>

Closes #1133 from ueshin/issues/SPARK-2196 and squashes the following commits:

510f12d [Takuya UESHIN] Add some tests.
dc25e8d [Takuya UESHIN] Fix nullable of CaseWhen to be nullable if the elseValue is nullable.
4f049cc [Takuya UESHIN] Fix nullability of CaseWhen.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/32495289
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/32495289
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/32495289

Branch: refs/heads/master
Commit: 324952892085d1933bcf392ce8f2ced452fe741e
Parents: f46e02f
Author: Takuya UESHIN <ue...@happy-camper.st>
Authored: Fri Jun 20 00:12:52 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Fri Jun 20 00:12:52 2014 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/predicates.scala   |  4 +-
 .../expressions/ExpressionEvaluationSuite.scala | 43 ++++++++++++++++++++
 2 files changed, 46 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/32495289/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 2902906..2718d43 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -233,10 +233,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
     branches.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq
   @transient private[this] lazy val values =  // THEN value of every complete (cond, value) pair
     branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq
+  @transient private[this] lazy val elseValue =  // odd-length branches end with the ELSE value
+    if (branches.length % 2 != 0) Option(branches.last) else None
 
   override def nullable = {
     // Nullable if any value is nullable, or if the else value is nullable or was not provided.
-    values.exists(_.nullable) || (values.length % 2 == 0)
+    values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
   }
 
   override lazy val resolved = {

http://git-wip-us.apache.org/repos/asf/spark/blob/32495289/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 8c3b062..84d7281 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -333,6 +333,49 @@ class ExpressionEvaluationSuite extends FunSuite {
       Literal("^Ba*n", StringType) :: c2 :: Nil), true, row)
   }
 
+  test("case when") {
+    val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c"))
+    val c1 = 'a.boolean.at(0)  // null
+    val c2 = 'a.boolean.at(1)  // false
+    val c3 = 'a.boolean.at(2)  // true
+    val c4 = 'a.string.at(3)   // "a"
+    val c5 = 'a.string.at(4)   // "b"
+    val c6 = 'a.string.at(5)   // "c"
+    // One WHEN/THEN pair plus ELSE: a null or false condition falls through to the ELSE value.
+    checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row)
+    checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row)
+    checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row)
+    checkEvaluation(CaseWhen(Seq(Literal(null, BooleanType), c4, c6)), "c", row)
+    checkEvaluation(CaseWhen(Seq(Literal(false, BooleanType), c4, c6)), "c", row)
+    checkEvaluation(CaseWhen(Seq(Literal(true, BooleanType), c4, c6)), "a", row)
+    // Two WHEN/THEN pairs: the first true condition wins; with no ELSE and no match, null.
+    checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row)
+    checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row)
+    checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row)
+    checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row)
+    // With the plain c4..c6 columns the expression is expected to be nullable in every shape.
+    assert(CaseWhen(Seq(c2, c4, c6)).nullable === true)
+    assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true)
+    assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true)
+    // Non-nullable counterparts of the string columns at ordinals 3..5 (same type as c4..c6).
+    val c4_notNull = 'a.string.notNull.at(3)
+    val c5_notNull = 'a.string.notNull.at(4)
+    val c6_notNull = 'a.string.notNull.at(5)
+    // One pair plus ELSE: non-nullable only if both the value and the ELSE are non-nullable.
+    assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false)
+    assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true)
+    assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true)
+    // Two pairs plus ELSE: any nullable value or ELSE makes the whole expression nullable.
+    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false)
+    assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true)
+    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true)
+    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true)
+    // No ELSE: always nullable, whatever the nullability of the values.
+    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true)
+    assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true)
+    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true)
+  }
+
   test("complex type") {
     val row = new GenericRow(Array[Any](
       "^Ba*n",                                  // 0