You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/06/08 03:45:30 UTC

[3/3] spark git commit: [SPARK-8149][SQL] Break ExpressionEvaluationSuite down to multiple files

[SPARK-8149][SQL] Break ExpressionEvaluationSuite down to multiple files

Also moved a few files in expressions package around to match test suites.

Author: Reynold Xin <rx...@databricks.com>

Closes #6693 from rxin/expr-refactoring and squashes the following commits:

857599f [Reynold Xin] Fixed style violation.
c0eb74b [Reynold Xin] Fixed compilation.
b3a40f8 [Reynold Xin] Refactored expression test suites.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f74be744
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f74be744
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f74be744

Branch: refs/heads/master
Commit: f74be744d41586690e73ec57e5551c1fbabc1d6f
Parents: 5e7b6b6
Author: Reynold Xin <rx...@databricks.com>
Authored: Sun Jun 7 18:45:24 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sun Jun 7 18:45:24 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/Cast.scala   |    2 +-
 .../sql/catalyst/expressions/Expression.scala   |    6 +-
 .../sql/catalyst/expressions/arithmetic.scala   |   96 --
 .../sql/catalyst/expressions/bitwise.scala      |  120 ++
 .../sql/catalyst/expressions/conditionals.scala |  313 ++++
 .../spark/sql/catalyst/expressions/math.scala   |  204 +++
 .../catalyst/expressions/mathfuncs/binary.scala |  100 --
 .../catalyst/expressions/mathfuncs/unary.scala  |  113 --
 .../sql/catalyst/expressions/predicates.scala   |  290 ----
 .../expressions/ArithmeticExpressionSuite.scala |  144 ++
 .../expressions/BitwiseFunctionsSuite.scala     |   80 +
 .../sql/catalyst/expressions/CastSuite.scala    |  532 +++++++
 .../expressions/CodeGenerationSuite.scala       |   45 +
 .../catalyst/expressions/ComplexTypeSuite.scala |  122 ++
 .../ConditionalExpressionSuite.scala            |   96 ++
 .../expressions/ExpressionEvalHelper.scala      |  134 ++
 .../expressions/ExpressionEvaluationSuite.scala | 1461 ------------------
 .../expressions/GeneratedEvaluationSuite.scala  |   44 -
 .../expressions/LiteralExpressionSuite.scala    |   55 +
 .../expressions/MathFunctionsSuite.scala        |  179 +++
 .../expressions/NullFunctionsSuite.scala        |   65 +
 .../catalyst/expressions/PredicateSuite.scala   |  179 +++
 .../expressions/StringFunctionsSuite.scala      |  218 +++
 .../optimizer/ExpressionOptimizationSuite.scala |    3 +-
 .../scala/org/apache/spark/sql/functions.scala  |    1 -
 25 files changed, 2492 insertions(+), 2110 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f74be744/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 5f76a51..2a1f964 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -161,7 +161,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
         try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null }
       })
     case BooleanType =>
-      buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0)))
+      buildCast[Boolean](_, b => new Timestamp(if (b) 1 else 0))
     case LongType =>
       buildCast[Long](_, l => new Timestamp(l))
     case IntegerType =>

http://git-wip-us.apache.org/repos/asf/spark/blob/f74be744/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 0ed576b..432d65e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -81,11 +81,11 @@ abstract class Expression extends TreeNode[Expression] {
     val objectTerm = ctx.freshName("obj")
     s"""
       /* expression: ${this} */
-      Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
-      boolean ${ev.isNull} = ${objectTerm} == null;
+      Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i);
+      boolean ${ev.isNull} = $objectTerm == null;
       ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)};
       if (!${ev.isNull}) {
-        ${ev.primitive} = (${ctx.boxedType(this.dataType)})${objectTerm};
+        ${ev.primitive} = (${ctx.boxedType(this.dataType)}) $objectTerm;
       }
     """
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/f74be744/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 3ac7c92..d4efda2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -322,102 +322,6 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
   }
 }
 
-/**
- * A function that calculates bitwise and(&) of two numbers.
- */
-case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
-  override def symbol: String = "&"
-
-  protected def checkTypesInternal(t: DataType) =
-    TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
-
-  private lazy val and: (Any, Any) => Any = dataType match {
-    case ByteType =>
-      ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any]
-    case ShortType =>
-      ((evalE1: Short, evalE2: Short) => (evalE1 & evalE2).toShort).asInstanceOf[(Any, Any) => Any]
-    case IntegerType =>
-      ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any]
-    case LongType =>
-      ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any]
-  }
-
-  protected override def evalInternal(evalE1: Any, evalE2: Any) = and(evalE1, evalE2)
-}
-
-/**
- * A function that calculates bitwise or(|) of two numbers.
- */
-case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
-  override def symbol: String = "|"
-
-  protected def checkTypesInternal(t: DataType) =
-    TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
-
-  private lazy val or: (Any, Any) => Any = dataType match {
-    case ByteType =>
-      ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any]
-    case ShortType =>
-      ((evalE1: Short, evalE2: Short) => (evalE1 | evalE2).toShort).asInstanceOf[(Any, Any) => Any]
-    case IntegerType =>
-      ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any]
-    case LongType =>
-      ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any]
-  }
-
-  protected override def evalInternal(evalE1: Any, evalE2: Any) = or(evalE1, evalE2)
-}
-
-/**
- * A function that calculates bitwise xor of two numbers.
- */
-case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
-  override def symbol: String = "^"
-
-  protected def checkTypesInternal(t: DataType) =
-    TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
-
-  private lazy val xor: (Any, Any) => Any = dataType match {
-    case ByteType =>
-      ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any]
-    case ShortType =>
-      ((evalE1: Short, evalE2: Short) => (evalE1 ^ evalE2).toShort).asInstanceOf[(Any, Any) => Any]
-    case IntegerType =>
-      ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any]
-    case LongType =>
-      ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any]
-  }
-
-  protected override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2)
-}
-
-/**
- * A function that calculates bitwise not(~) of a number.
- */
-case class BitwiseNot(child: Expression) extends UnaryArithmetic {
-  override def toString: String = s"~$child"
-
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~")
-
-  private lazy val not: (Any) => Any = dataType match {
-    case ByteType =>
-      ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any]
-    case ShortType =>
-      ((evalE: Short) => (~evalE).toShort).asInstanceOf[(Any) => Any]
-    case IntegerType =>
-      ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any]
-    case LongType =>
-      ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any]
-  }
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
-    defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)})~($c)")
-  }
-
-  protected override def evalInternal(evalE: Any) = not(evalE)
-}
-
 case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
   override def nullable: Boolean = left.nullable && right.nullable
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f74be744/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
new file mode 100644
index 0000000..ef34586
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.types._
+
+
+/**
+ * A function that calculates bitwise and(&) of two numbers.
+ */
+case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
+  override def symbol: String = "&"
+
+  protected def checkTypesInternal(t: DataType) =
+    TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
+
+  private lazy val and: (Any, Any) => Any = dataType match {
+    case ByteType =>
+      ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any]
+    case ShortType =>
+      ((evalE1: Short, evalE2: Short) => (evalE1 & evalE2).toShort).asInstanceOf[(Any, Any) => Any]
+    case IntegerType =>
+      ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any]
+    case LongType =>
+      ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any]
+  }
+
+  protected override def evalInternal(evalE1: Any, evalE2: Any) = and(evalE1, evalE2)
+}
+
+/**
+ * A function that calculates bitwise or(|) of two numbers.
+ */
+case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
+  override def symbol: String = "|"
+
+  protected def checkTypesInternal(t: DataType) =
+    TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
+
+  private lazy val or: (Any, Any) => Any = dataType match {
+    case ByteType =>
+      ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any]
+    case ShortType =>
+      ((evalE1: Short, evalE2: Short) => (evalE1 | evalE2).toShort).asInstanceOf[(Any, Any) => Any]
+    case IntegerType =>
+      ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any]
+    case LongType =>
+      ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any]
+  }
+
+  protected override def evalInternal(evalE1: Any, evalE2: Any) = or(evalE1, evalE2)
+}
+
+/**
+ * A function that calculates bitwise xor of two numbers.
+ */
+case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
+  override def symbol: String = "^"
+
+  protected def checkTypesInternal(t: DataType) =
+    TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
+
+  private lazy val xor: (Any, Any) => Any = dataType match {
+    case ByteType =>
+      ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any]
+    case ShortType =>
+      ((evalE1: Short, evalE2: Short) => (evalE1 ^ evalE2).toShort).asInstanceOf[(Any, Any) => Any]
+    case IntegerType =>
+      ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any]
+    case LongType =>
+      ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any]
+  }
+
+  protected override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2)
+}
+
+/**
+ * A function that calculates bitwise not(~) of a number.
+ */
+case class BitwiseNot(child: Expression) extends UnaryArithmetic {
+  override def toString: String = s"~$child"
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~")
+
+  private lazy val not: (Any) => Any = dataType match {
+    case ByteType =>
+      ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any]
+    case ShortType =>
+      ((evalE: Short) => (~evalE).toShort).asInstanceOf[(Any) => Any]
+    case IntegerType =>
+      ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any]
+    case LongType =>
+      ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any]
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+    defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)})~($c)")
+  }
+
+  protected override def evalInternal(evalE: Any) = not(evalE)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f74be744/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
new file mode 100644
index 0000000..3aa86ed
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
@@ -0,0 +1,313 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.types.{BooleanType, DataType}
+
+
+case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
+  extends Expression {
+
+  override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil
+  override def nullable: Boolean = trueValue.nullable || falseValue.nullable
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (predicate.dataType != BooleanType) {
+      TypeCheckResult.TypeCheckFailure(
+        s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
+    } else if (trueValue.dataType != falseValue.dataType) {
+      TypeCheckResult.TypeCheckFailure(
+        s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def dataType: DataType = trueValue.dataType
+
+  override def eval(input: Row): Any = {
+    if (true == predicate.eval(input)) {
+      trueValue.eval(input)
+    } else {
+      falseValue.eval(input)
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+    val condEval = predicate.gen(ctx)
+    val trueEval = trueValue.gen(ctx)
+    val falseEval = falseValue.gen(ctx)
+
+    s"""
+      ${condEval.code}
+      boolean ${ev.isNull} = false;
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      if (!${condEval.isNull} && ${condEval.primitive}) {
+        ${trueEval.code}
+        ${ev.isNull} = ${trueEval.isNull};
+        ${ev.primitive} = ${trueEval.primitive};
+      } else {
+        ${falseEval.code}
+        ${ev.isNull} = ${falseEval.isNull};
+        ${ev.primitive} = ${falseEval.primitive};
+      }
+    """
+  }
+
+  override def toString: String = s"if ($predicate) $trueValue else $falseValue"
+}
+
+trait CaseWhenLike extends Expression {
+  self: Product =>
+
+  // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
+  // element is the value for the default catch-all case (if provided).
+  // Hence, `branches` consists of at least two elements, and can have an odd or even length.
+  def branches: Seq[Expression]
+
+  @transient lazy val whenList =
+    branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq
+  @transient lazy val thenList =
+    branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq
+  val elseValue = if (branches.length % 2 == 0) None else Option(branches.last)
+
+  // both then and else expressions should be considered.
+  def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
+  def valueTypesEqual: Boolean = valueTypes.distinct.size == 1
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (valueTypesEqual) {
+      checkTypesInternal()
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        "THEN and ELSE expressions should all be same type or coercible to a common type")
+    }
+  }
+
+  protected def checkTypesInternal(): TypeCheckResult
+
+  override def dataType: DataType = thenList.head.dataType
+
+  override def nullable: Boolean = {
+    // If no value is nullable and no elseValue is provided, the whole statement defaults to null.
+    thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
+  }
+}
+
+// scalastyle:off
+/**
+ * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
+ * Refer to this link for the corresponding semantics:
+ * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
+ */
+// scalastyle:on
+case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
+
+  // Use private[this] Array to speed up evaluation.
+  @transient private[this] lazy val branchesArr = branches.toArray
+
+  override def children: Seq[Expression] = branches
+
+  override protected def checkTypesInternal(): TypeCheckResult = { // WHENs must be boolean
+    if (whenList.forall(_.dataType == BooleanType)) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      val index = whenList.indexWhere(_.dataType != BooleanType) // first non-boolean WHEN
+      TypeCheckResult.TypeCheckFailure(
+        s"WHEN expressions in CaseWhen should all be boolean type, " +
+          s"but the ${index + 1}th when expression's type is ${whenList(index).dataType}")
+    }
+  }
+
+  /** Written in imperative fashion for performance considerations. */
+  override def eval(input: Row): Any = {
+    val len = branchesArr.length
+    var i = 0
+    // If all branches fail and an elseVal is not provided, the whole statement
+    // defaults to null, according to Hive's semantics.
+    while (i < len - 1) {
+      if (branchesArr(i).eval(input) == true) {
+        return branchesArr(i + 1).eval(input)
+      }
+      i += 2
+    }
+    var res: Any = null
+    if (i == len - 1) {
+      res = branchesArr(i).eval(input)
+    }
+    return res
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+    val len = branchesArr.length
+    val got = ctx.freshName("got")
+
+    val cases = (0 until len/2).map { i =>
+      val cond = branchesArr(i * 2).gen(ctx)
+      val res = branchesArr(i * 2 + 1).gen(ctx)
+      s"""
+        if (!$got) {
+          ${cond.code}
+          if (!${cond.isNull} && ${cond.primitive}) {
+            $got = true;
+            ${res.code}
+            ${ev.isNull} = ${res.isNull};
+            ${ev.primitive} = ${res.primitive};
+          }
+        }
+      """
+    }.mkString("\n")
+
+    val other = if (len % 2 == 1) {
+      val res = branchesArr(len - 1).gen(ctx)
+      s"""
+        if (!$got) {
+          ${res.code}
+          ${ev.isNull} = ${res.isNull};
+          ${ev.primitive} = ${res.primitive};
+        }
+      """
+    } else {
+      ""
+    }
+
+    s"""
+      boolean $got = false;
+      boolean ${ev.isNull} = true;
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      $cases
+      $other
+    """
+  }
+
+  override def toString: String = {
+    "CASE" + branches.sliding(2, 2).map {
+      case Seq(cond, value) => s" WHEN $cond THEN $value"
+      case Seq(elseValue) => s" ELSE $elseValue"
+    }.mkString
+  }
+}
+
+// scalastyle:off
+/**
+ * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
+ * Refer to this link for the corresponding semantics:
+ * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
+ */
+// scalastyle:on
+case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike {
+
+  // Use private[this] Array to speed up evaluation.
+  @transient private[this] lazy val branchesArr = branches.toArray
+
+  override def children: Seq[Expression] = key +: branches
+
+  override protected def checkTypesInternal(): TypeCheckResult = {
+    if ((key +: whenList).map(_.dataType).distinct.size > 1) {
+      TypeCheckResult.TypeCheckFailure(
+        "key and WHEN expressions should all be same type or coercible to a common type")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  /** Written in imperative fashion for performance considerations. */
+  override def eval(input: Row): Any = {
+    val evaluatedKey = key.eval(input)
+    val len = branchesArr.length
+    var i = 0
+    // If all branches fail and an elseVal is not provided, the whole statement
+    // defaults to null, according to Hive's semantics.
+    while (i < len - 1) {
+      if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) {
+        return branchesArr(i + 1).eval(input)
+      }
+      i += 2
+    }
+    var res: Any = null
+    if (i == len - 1) {
+      res = branchesArr(i).eval(input)
+    }
+    return res
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+    val keyEval = key.gen(ctx)
+    val len = branchesArr.length
+    val got = ctx.freshName("got")
+
+    val cases = (0 until len/2).map { i =>
+      val cond = branchesArr(i * 2).gen(ctx)
+      val res = branchesArr(i * 2 + 1).gen(ctx)
+      s"""
+        if (!$got) {
+          ${cond.code}
+          if (${keyEval.isNull} && ${cond.isNull} ||
+            !${keyEval.isNull} && !${cond.isNull}
+             && ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) {
+            $got = true;
+            ${res.code}
+            ${ev.isNull} = ${res.isNull};
+            ${ev.primitive} = ${res.primitive};
+          }
+        }
+      """
+    }.mkString("\n")
+
+    val other = if (len % 2 == 1) {
+      val res = branchesArr(len - 1).gen(ctx)
+      s"""
+        if (!$got) {
+          ${res.code}
+          ${ev.isNull} = ${res.isNull};
+          ${ev.primitive} = ${res.primitive};
+        }
+      """
+    } else {
+      ""
+    }
+
+    s"""
+      boolean $got = false;
+      boolean ${ev.isNull} = true;
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      ${keyEval.code}
+      $cases
+      $other
+    """
+  }
+
+  private def equalNullSafe(l: Any, r: Any) = {
+    if (l == null && r == null) {
+      true
+    } else if (l == null || r == null) {
+      false
+    } else {
+      l == r
+    }
+  }
+
+  override def toString: String = {
+    s"CASE $key" + branches.sliding(2, 2).map {
+      case Seq(cond, value) => s" WHEN $cond THEN $value"
+      case Seq(elseValue) => s" ELSE $elseValue"
+    }.mkString
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f74be744/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
new file mode 100644
index 0000000..a18067e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.types.{DataType, DoubleType}
+
+/**
+ * A unary expression specifically for math functions. Math Functions expect a specific type of
+ * input format, therefore these functions extend `ExpectsInputTypes`.
+ * @param name The short name of the function
+ */
+abstract class UnaryMathExpression(f: Double => Double, name: String)
+  extends UnaryExpression with Serializable with ExpectsInputTypes {
+  self: Product =>
+
+  override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
+  override def dataType: DataType = DoubleType
+  override def foldable: Boolean = child.foldable
+  override def nullable: Boolean = true
+  override def toString: String = s"$name($child)"
+
+  override def eval(input: Row): Any = {
+    val evalE = child.eval(input)
+    if (evalE == null) {
+      null
+    } else {
+      val result = f(evalE.asInstanceOf[Double])
+      if (result.isNaN) null else result
+    }
+  }
+
+  // name of function in java.lang.Math
+  def funcName: String = name.toLowerCase
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { // NaN => null
+    val eval = child.gen(ctx) // generated code for the child expression
+    eval.code + s"""
+      boolean ${ev.isNull} = ${eval.isNull};
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive});
+        if (Double.isNaN(${ev.primitive})) {
+          ${ev.isNull} = true;
+        }
+      }
+    """
+  }
+}
+
+/**
+ * A binary expression specifically for math functions that take two `Double`s as input and returns
+ * a `Double`.
+ * @param f The math function.
+ * @param name The short name of the function
+ */
+abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
+  extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product =>
+
+  override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
+
+  override def toString: String = s"$name($left, $right)"
+
+  override def dataType: DataType = DoubleType
+
+  override def eval(input: Row): Any = {
+    val evalE1 = left.eval(input)
+    if (evalE1 == null) {
+      null
+    } else {
+      val evalE2 = right.eval(input)
+      if (evalE2 == null) {
+        null
+      } else {
+        val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double])
+        if (result.isNaN) null else result
+      }
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = // NaN => null
+    defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") +
+      s"\nif (Double.isNaN(${ev.primitive})) { ${ev.isNull} = true; }\n"
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// Unary math functions
+////////////////////////////////////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS")
+
+case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN")
+
+case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN")
+
+case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT")
+
+case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL")
+
+case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS")
+
+case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH")
+
+case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP")
+
+case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1")
+
+case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR")
+
+case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG")
+
+case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10")
+
+case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P")
+
+case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") {
+  override def funcName: String = "rint"
+}
+
+case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM")
+
+case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN")
+
+case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH")
+
+case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN")
+
+case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH")
+
+case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") {
+  override def funcName: String = "toDegrees"
+}
+
+case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") {
+  override def funcName: String = "toRadians"
+}
+
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// Binary math functions
+////////////////////////////////////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+case class Atan2(left: Expression, right: Expression)
+  extends BinaryMathExpression(math.atan2, "ATAN2") {
+
+  override def eval(input: Row): Any = {
+    val evalE1 = left.eval(input)
+    if (evalE1 == null) {
+      null
+    } else {
+      val evalE2 = right.eval(input)
+      if (evalE2 == null) {
+        null
+      } else {
+        // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0
+        val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0,
+          evalE2.asInstanceOf[Double] + 0.0)
+        if (result.isNaN) null else result
+      }
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { // NaN => null
+    defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s"""
+      if (Double.isNaN(${ev.primitive})) {
+        ${ev.isNull} = true;
+      }
+      """
+  }
+}
+
+case class Hypot(left: Expression, right: Expression)
+  extends BinaryMathExpression(math.hypot, "HYPOT")
+
+case class Pow(left: Expression, right: Expression)
+  extends BinaryMathExpression(math.pow, "POWER") {
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { // NaN => null
+    defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s"""
+      if (Double.isNaN(${ev.primitive})) {
+        ${ev.isNull} = true;
+      }
+      """
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f74be744/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala
deleted file mode 100644
index 88211ac..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.mathfuncs
-
-import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row}
-import org.apache.spark.sql.types._
-
-/**
- * A binary expression specifically for math functions that take two `Double`s as input and returns
- * a `Double`.
- * @param f The math function.
- * @param name The short name of the function
- */
-abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
-  extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product =>
-
-  override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
-
-  override def toString: String = s"$name($left, $right)"
-
-  override def dataType: DataType = DoubleType
-
-  override def eval(input: Row): Any = {
-    val evalE1 = left.eval(input)
-    if (evalE1 == null) {
-      null
-    } else {
-      val evalE2 = right.eval(input)
-      if (evalE2 == null) {
-        null
-      } else {
-        val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double])
-        if (result.isNaN) null else result
-      }
-    }
-  }
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
-    defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)")
-  }
-}
-
-case class Atan2(left: Expression, right: Expression)
-  extends BinaryMathExpression(math.atan2, "ATAN2") {
-
-  override def eval(input: Row): Any = {
-    val evalE1 = left.eval(input)
-    if (evalE1 == null) {
-      null
-    } else {
-      val evalE2 = right.eval(input)
-      if (evalE2 == null) {
-        null
-      } else {
-        // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0
-        val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0,
-          evalE2.asInstanceOf[Double] + 0.0)
-        if (result.isNaN) null else result
-      }
-    }
-  }
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
-    defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s"""
-      if (Double.valueOf(${ev.primitive}).isNaN()) {
-        ${ev.isNull} = true;
-      }
-      """
-  }
-}
-
-case class Hypot(left: Expression, right: Expression)
-  extends BinaryMathExpression(math.hypot, "HYPOT")
-
-case class Pow(left: Expression, right: Expression)
-  extends BinaryMathExpression(math.pow, "POWER") {
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
-    defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s"""
-      if (Double.valueOf(${ev.primitive}).isNaN()) {
-        ${ev.isNull} = true;
-      }
-      """
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f74be744/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala
deleted file mode 100644
index 5563cd9..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala
+++ /dev/null
@@ -1,113 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.mathfuncs
-
-import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode}
-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression}
-import org.apache.spark.sql.types._
-
-/**
- * A unary expression specifically for math functions. Math Functions expect a specific type of
- * input format, therefore these functions extend `ExpectsInputTypes`.
- * @param name The short name of the function
- */
-abstract class UnaryMathExpression(f: Double => Double, name: String)
-  extends UnaryExpression with Serializable with ExpectsInputTypes {
-  self: Product =>
-
-  override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
-  override def dataType: DataType = DoubleType
-  override def foldable: Boolean = child.foldable
-  override def nullable: Boolean = true
-  override def toString: String = s"$name($child)"
-
-  override def eval(input: Row): Any = {
-    val evalE = child.eval(input)
-    if (evalE == null) {
-      null
-    } else {
-      val result = f(evalE.asInstanceOf[Double])
-      if (result.isNaN) null else result
-    }
-  }
-
-  // name of function in java.lang.Math
-  def funcName: String = name.toLowerCase
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
-    val eval = child.gen(ctx)
-    eval.code + s"""
-      boolean ${ev.isNull} = ${eval.isNull};
-      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
-      if (!${ev.isNull}) {
-        ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive});
-        if (Double.valueOf(${ev.primitive}).isNaN()) {
-          ${ev.isNull} = true;
-        }
-      }
-    """
-  }
-}
-
-case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS")
-
-case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN")
-
-case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN")
-
-case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT")
-
-case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL")
-
-case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS")
-
-case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH")
-
-case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP")
-
-case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1")
-
-case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR")
-
-case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG")
-
-case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10")
-
-case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P")
-
-case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") {
-  override def funcName: String = "rint"
-}
-
-case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM")
-
-case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN")
-
-case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH")
-
-case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN")
-
-case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH")
-
-case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") {
-  override def funcName: String = "toDegrees"
-}
-
-case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") {
-  override def funcName: String = "toRadians"
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f74be744/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 1d0f19a..5edcf3b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -359,293 +359,3 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar
 
   protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gteq(evalE1, evalE2)
 }
-
-case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
-  extends Expression {
-
-  override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil
-  override def nullable: Boolean = trueValue.nullable || falseValue.nullable
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (predicate.dataType != BooleanType) {
-      TypeCheckResult.TypeCheckFailure(
-        s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
-    } else if (trueValue.dataType != falseValue.dataType) {
-      TypeCheckResult.TypeCheckFailure(
-        s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).")
-    } else {
-      TypeCheckResult.TypeCheckSuccess
-    }
-  }
-
-  override def dataType: DataType = trueValue.dataType
-
-  override def eval(input: Row): Any = {
-    if (true == predicate.eval(input)) {
-      trueValue.eval(input)
-    } else {
-      falseValue.eval(input)
-    }
-  }
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
-    val condEval = predicate.gen(ctx)
-    val trueEval = trueValue.gen(ctx)
-    val falseEval = falseValue.gen(ctx)
-
-    s"""
-      ${condEval.code}
-      boolean ${ev.isNull} = false;
-      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
-      if (!${condEval.isNull} && ${condEval.primitive}) {
-        ${trueEval.code}
-        ${ev.isNull} = ${trueEval.isNull};
-        ${ev.primitive} = ${trueEval.primitive};
-      } else {
-        ${falseEval.code}
-        ${ev.isNull} = ${falseEval.isNull};
-        ${ev.primitive} = ${falseEval.primitive};
-      }
-    """
-  }
-
-  override def toString: String = s"if ($predicate) $trueValue else $falseValue"
-}
-
-trait CaseWhenLike extends Expression {
-  self: Product =>
-
-  // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
-  // element is the value for the default catch-all case (if provided).
-  // Hence, `branches` consists of at least two elements, and can have an odd or even length.
-  def branches: Seq[Expression]
-
-  @transient lazy val whenList =
-    branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq
-  @transient lazy val thenList =
-    branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq
-  val elseValue = if (branches.length % 2 == 0) None else Option(branches.last)
-
-  // both then and else expressions should be considered.
-  def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
-  def valueTypesEqual: Boolean = valueTypes.distinct.size == 1
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (valueTypesEqual) {
-      checkTypesInternal()
-    } else {
-      TypeCheckResult.TypeCheckFailure(
-        "THEN and ELSE expressions should all be same type or coercible to a common type")
-    }
-  }
-
-  protected def checkTypesInternal(): TypeCheckResult
-
-  override def dataType: DataType = thenList.head.dataType
-
-  override def nullable: Boolean = {
-    // If no value is nullable and no elseValue is provided, the whole statement defaults to null.
-    thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
-  }
-}
-
-// scalastyle:off
-/**
- * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
- * Refer to this link for the corresponding semantics:
- * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
- */
-// scalastyle:on
-case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
-
-  // Use private[this] Array to speed up evaluation.
-  @transient private[this] lazy val branchesArr = branches.toArray
-
-  override def children: Seq[Expression] = branches
-
-  override protected def checkTypesInternal(): TypeCheckResult = {
-    if (whenList.forall(_.dataType == BooleanType)) {
-      TypeCheckResult.TypeCheckSuccess
-    } else {
-      val index = whenList.indexWhere(_.dataType != BooleanType)
-      TypeCheckResult.TypeCheckFailure(
-        s"WHEN expressions in CaseWhen should all be boolean type, " +
-        s"but the ${index + 1}th when expression's type is ${whenList(index)}")
-    }
-  }
-
-  /** Written in imperative fashion for performance considerations. */
-  override def eval(input: Row): Any = {
-    val len = branchesArr.length
-    var i = 0
-    // If all branches fail and an elseVal is not provided, the whole statement
-    // defaults to null, according to Hive's semantics.
-    while (i < len - 1) {
-      if (branchesArr(i).eval(input) == true) {
-        return branchesArr(i + 1).eval(input)
-      }
-      i += 2
-    }
-    var res: Any = null
-    if (i == len - 1) {
-      res = branchesArr(i).eval(input)
-    }
-    return res
-  }
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
-    val len = branchesArr.length
-    val got = ctx.freshName("got")
-
-    val cases = (0 until len/2).map { i =>
-      val cond = branchesArr(i * 2).gen(ctx)
-      val res = branchesArr(i * 2 + 1).gen(ctx)
-      s"""
-        if (!$got) {
-          ${cond.code}
-          if (!${cond.isNull} && ${cond.primitive}) {
-            $got = true;
-            ${res.code}
-            ${ev.isNull} = ${res.isNull};
-            ${ev.primitive} = ${res.primitive};
-          }
-        }
-      """
-    }.mkString("\n")
-
-    val other = if (len % 2 == 1) {
-      val res = branchesArr(len - 1).gen(ctx)
-      s"""
-        if (!$got) {
-          ${res.code}
-          ${ev.isNull} = ${res.isNull};
-          ${ev.primitive} = ${res.primitive};
-        }
-      """
-    } else {
-      ""
-    }
-
-    s"""
-      boolean $got = false;
-      boolean ${ev.isNull} = true;
-      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
-      $cases
-      $other
-    """
-  }
-
-  override def toString: String = {
-    "CASE" + branches.sliding(2, 2).map {
-      case Seq(cond, value) => s" WHEN $cond THEN $value"
-      case Seq(elseValue) => s" ELSE $elseValue"
-    }.mkString
-  }
-}
-
-// scalastyle:off
-/**
- * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
- * Refer to this link for the corresponding semantics:
- * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
- */
-// scalastyle:on
-case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike {
-
-  // Use private[this] Array to speed up evaluation.
-  @transient private[this] lazy val branchesArr = branches.toArray
-
-  override def children: Seq[Expression] = key +: branches
-
-  override protected def checkTypesInternal(): TypeCheckResult = {
-    if ((key +: whenList).map(_.dataType).distinct.size > 1) {
-      TypeCheckResult.TypeCheckFailure(
-        "key and WHEN expressions should all be same type or coercible to a common type")
-    } else {
-      TypeCheckResult.TypeCheckSuccess
-    }
-  }
-
-  /** Written in imperative fashion for performance considerations. */
-  override def eval(input: Row): Any = {
-    val evaluatedKey = key.eval(input)
-    val len = branchesArr.length
-    var i = 0
-    // If all branches fail and an elseVal is not provided, the whole statement
-    // defaults to null, according to Hive's semantics.
-    while (i < len - 1) {
-      if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) {
-        return branchesArr(i + 1).eval(input)
-      }
-      i += 2
-    }
-    var res: Any = null
-    if (i == len - 1) {
-      res = branchesArr(i).eval(input)
-    }
-    return res
-  }
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
-    val keyEval = key.gen(ctx)
-    val len = branchesArr.length
-    val got = ctx.freshName("got")
-
-    val cases = (0 until len/2).map { i =>
-      val cond = branchesArr(i * 2).gen(ctx)
-      val res = branchesArr(i * 2 + 1).gen(ctx)
-      s"""
-        if (!$got) {
-          ${cond.code}
-          if (${keyEval.isNull} && ${cond.isNull} ||
-            !${keyEval.isNull} && !${cond.isNull}
-             && ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) {
-            $got = true;
-            ${res.code}
-            ${ev.isNull} = ${res.isNull};
-            ${ev.primitive} = ${res.primitive};
-          }
-        }
-      """
-    }.mkString("\n")
-
-    val other = if (len % 2 == 1) {
-      val res = branchesArr(len - 1).gen(ctx)
-      s"""
-        if (!$got) {
-          ${res.code}
-          ${ev.isNull} = ${res.isNull};
-          ${ev.primitive} = ${res.primitive};
-        }
-      """
-    } else {
-      ""
-    }
-
-    s"""
-      boolean $got = false;
-      boolean ${ev.isNull} = true;
-      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
-      ${keyEval.code}
-      $cases
-      $other
-    """
-  }
-
-  private def equalNullSafe(l: Any, r: Any) = {
-    if (l == null && r == null) {
-      true
-    } else if (l == null || r == null) {
-      false
-    } else {
-      l == r
-    }
-  }
-
-  override def toString: String = {
-    s"CASE $key" + branches.sliding(2, 2).map {
-      case Seq(cond, value) => s" WHEN $cond THEN $value"
-      case Seq(elseValue) => s" ELSE $elseValue"
-    }.mkString
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f74be744/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
new file mode 100644
index 0000000..e1afa81
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.scalatest.Matchers._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.types.{DoubleType, IntegerType}
+
+
+class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
+  // Evaluation checks for arithmetic expressions, including null operands and division by zero.
+  test("arithmetic") {
+    val row = create_row(1, 2, 3, null)  // columns 0..2 hold ints, column 3 is null
+    val c1 = 'a.int.at(0)
+    val c2 = 'a.int.at(1)
+    val c3 = 'a.int.at(2)  // NOTE(review): c3 is not referenced in this test
+    val c4 = 'a.int.at(3)  // bound to the null column
+
+    checkEvaluation(UnaryMinus(c1), -1, row)
+    checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100)
+    // A null operand on either side makes Add evaluate to null.
+    checkEvaluation(Add(c1, c4), null, row)
+    checkEvaluation(Add(c1, c2), 3, row)
+    checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row)
+    checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row)
+    checkEvaluation(
+      Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row)
+    // The same operations written with operator syntax.
+    checkEvaluation(-c1, -1, row)
+    checkEvaluation(c1 + c2, 3, row)
+    checkEvaluation(c1 - c2, -1, row)
+    checkEvaluation(c1 * c2, 2, row)
+    checkEvaluation(c1 / c2, 0, row)  // 1 / 2 on int columns is expected to truncate to 0
+    checkEvaluation(c1 % c2, 1, row)
+  }
+
+  test("fractional arithmetic") {
+    val row = create_row(1.1, 2.0, 3.1, null)  // columns 0..2 hold doubles, column 3 is null
+    val c1 = 'a.double.at(0)
+    val c2 = 'a.double.at(1)
+    val c3 = 'a.double.at(2)
+    val c4 = 'a.double.at(3)  // bound to the null column
+
+    checkEvaluation(UnaryMinus(c1), -1.1, row)
+    checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0)
+    checkEvaluation(Add(c1, c4), null, row)
+    checkEvaluation(Add(c1, c2), 3.1, row)
+    checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row)
+    checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row)
+    checkEvaluation(
+      Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row)
+
+    checkEvaluation(-c1, -1.1, row)
+    checkEvaluation(c1 + c2, 3.1, row)
+    checkDoubleEvaluation(c1 - c2, (-0.9 +- 0.001), row)  // checked within a tolerance
+    checkDoubleEvaluation(c1 * c2, (2.2 +- 0.001), row)
+    checkDoubleEvaluation(c1 / c2, (0.55 +- 0.001), row)
+    checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row)
+  }
+
+  test("Divide") {
+    checkEvaluation(Divide(Literal(2), Literal(1)), 2)
+    checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5)
+    checkEvaluation(Divide(Literal(1), Literal(2)), 0)  // int operands: result truncates
+    checkEvaluation(Divide(Literal(1), Literal(0)), null)  // divide by zero => null
+    checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null)
+    checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null)  // 0.0 / 0.0 => null as well
+    checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), null)
+    checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), null)
+    checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null)
+    checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null)
+    checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null)
+    checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)),
+      null)
+  }
+
+  test("Remainder") {
+    checkEvaluation(Remainder(Literal(2), Literal(1)), 0)
+    checkEvaluation(Remainder(Literal(1.0), Literal(2.0)), 1.0)
+    checkEvaluation(Remainder(Literal(1), Literal(2)), 1)
+    checkEvaluation(Remainder(Literal(1), Literal(0)), null)  // zero divisor => null
+    checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null)
+    checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null)
+    checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), null)
+    checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), null)
+    checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null)
+    checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null)
+    checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null)
+    checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)),
+      null)
+  }
+
+  test("MaxOf") {
+    checkEvaluation(MaxOf(1, 2), 2)
+    checkEvaluation(MaxOf(2, 1), 2)
+    checkEvaluation(MaxOf(1L, 2L), 2L)
+    checkEvaluation(MaxOf(2L, 1L), 2L)
+    // With one null operand the expected result is the other operand.
+    checkEvaluation(MaxOf(Literal.create(null, IntegerType), 2), 2)
+    checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2)
+  }
+
+  test("MinOf") {
+    checkEvaluation(MinOf(1, 2), 1)
+    checkEvaluation(MinOf(2, 1), 1)
+    checkEvaluation(MinOf(1L, 2L), 1L)
+    checkEvaluation(MinOf(2L, 1L), 1L)
+    // With one null operand the expected result is the other operand.
+    checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1)
+    checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1)
+  }
+
+  test("SQRT") {
+    val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24))  // Longs, multiples of 2^24
+    val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble))
+    val rowSequence = inputSequence.map(l => create_row(l.toDouble))
+    val d = 'a.double.at(0)
+
+    for ((row, expected) <- rowSequence zip expectedResults) {
+      checkEvaluation(Sqrt(d), expected, row)
+    }
+    // Null and negative inputs are expected to evaluate to null.
+    checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null))
+    checkEvaluation(Sqrt(-1), null, EmptyRow)
+    checkEvaluation(Sqrt(-1.5), null, EmptyRow)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f74be744/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala
new file mode 100644
index 0000000..c9bbc7a
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.types._
+
+
+class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+  // Evaluation checks for BitwiseAnd, BitwiseOr, BitwiseXor and BitwiseNot.
+  test("Bitwise operations") {
+    val row = create_row(1, 2, 3, null)  // columns 0..2 hold ints, column 3 is null
+    val c1 = 'a.int.at(0)
+    val c2 = 'a.int.at(1)
+    val c3 = 'a.int.at(2)  // NOTE(review): c3 is not referenced in this test
+    val c4 = 'a.int.at(3)  // bound to the null column
+
+    checkEvaluation(BitwiseAnd(c1, c4), null, row)
+    checkEvaluation(BitwiseAnd(c1, c2), 0, row)  // 1 & 2 == 0
+    checkEvaluation(BitwiseAnd(c1, Literal.create(null, IntegerType)), null, row)
+    checkEvaluation(
+      BitwiseAnd(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row)
+
+    checkEvaluation(BitwiseOr(c1, c4), null, row)
+    checkEvaluation(BitwiseOr(c1, c2), 3, row)  // 1 | 2 == 3
+    checkEvaluation(BitwiseOr(c1, Literal.create(null, IntegerType)), null, row)
+    checkEvaluation(
+      BitwiseOr(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row)
+
+    checkEvaluation(BitwiseXor(c1, c4), null, row)
+    checkEvaluation(BitwiseXor(c1, c2), 3, row)  // 1 ^ 2 == 3
+    checkEvaluation(BitwiseXor(c1, Literal.create(null, IntegerType)), null, row)
+    checkEvaluation(
+      BitwiseXor(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row)
+
+    checkEvaluation(BitwiseNot(c4), null, row)
+    checkEvaluation(BitwiseNot(c1), -2, row)  // ~1 == -2
+    checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null, row)
+    // The same operations written with operator syntax.
+    checkEvaluation(c1 & c2, 0, row)
+    checkEvaluation(c1 | c2, 3, row)
+    checkEvaluation(c1 ^ c2, 3, row)
+    checkEvaluation(~c1, -2, row)
+  }
+
+  test("unary BitwiseNOT") {  // per integral width: value, dataType and runtime class
+    checkEvaluation(BitwiseNot(1), -2)
+    assert(BitwiseNot(1).dataType === IntegerType)
+    assert(BitwiseNot(1).eval(EmptyRow).isInstanceOf[Int])
+
+    checkEvaluation(BitwiseNot(1.toLong), -2.toLong)
+    assert(BitwiseNot(1.toLong).dataType === LongType)
+    assert(BitwiseNot(1.toLong).eval(EmptyRow).isInstanceOf[Long])
+
+    checkEvaluation(BitwiseNot(1.toShort), -2.toShort)
+    assert(BitwiseNot(1.toShort).dataType === ShortType)
+    assert(BitwiseNot(1.toShort).eval(EmptyRow).isInstanceOf[Short])
+
+    checkEvaluation(BitwiseNot(1.toByte), -2.toByte)
+    assert(BitwiseNot(1.toByte).dataType === ByteType)
+    assert(BitwiseNot(1.toByte).eval(EmptyRow).isInstanceOf[Byte])
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f74be744/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
new file mode 100644
index 0000000..5bc7c30
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -0,0 +1,532 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.sql.{Timestamp, Date}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
+
+/**
+ * Test suite for data type casting expression [[Cast]].
+ */
+class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+  private def cast(v: Any, targetType: DataType): Cast = {
+    v match {
+      case lit: Expression => Cast(lit, targetType)  // already an expression: cast it directly
+      case _ => Cast(Literal(v), targetType)  // raw value: wrap it in a Literal first
+    }
+  }
+
+  // expected cannot be null; the dataType of Literal(expected) is used as the cast target type
+  private def checkCast(v: Any, expected: Any): Unit = {
+    checkEvaluation(cast(v, Literal(expected).dataType), expected)
+  }
+
+  test("cast from int") {
+    checkCast(0, false)  // checkCast derives the target type from the expected value
+    checkCast(1, true)
+    checkCast(5, true)  // 5 (non-zero, not 1) also yields true
+    checkCast(1, 1.toByte)
+    checkCast(1, 1.toShort)
+    checkCast(1, 1)
+    checkCast(1, 1.toLong)
+    checkCast(1, 1.0f)
+    checkCast(1, 1.0)
+    checkCast(123, "123")
+
+    checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123))
+    checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
+    checkEvaluation(cast(123, DecimalType(3, 1)), null)  // presumably 123.0 overflows precision 3
+    checkEvaluation(cast(123, DecimalType(2, 0)), null)
+  }
+
+  test("cast from long") {
+    checkCast(0L, false)  // checkCast derives the target type from the expected value
+    checkCast(1L, true)
+    checkCast(5L, true)
+    checkCast(1L, 1.toByte)
+    checkCast(1L, 1.toShort)
+    checkCast(1L, 1)
+    checkCast(1L, 1.toLong)
+    checkCast(1L, 1.0f)
+    checkCast(1L, 1.0)
+    checkCast(123L, "123")
+
+    checkEvaluation(cast(123L, DecimalType.Unlimited), Decimal(123))
+    checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123))
+    checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0))  // int variant expects null
+
+    // TODO: Fix the following bug (long -> DecimalType(2, 0) should yield null) and re-enable it.
+    // checkEvaluation(cast(123L, DecimalType(2, 0)), null)
+  }
+
+  test("cast from boolean") {
+    checkEvaluation(cast(true, IntegerType), 1)
+    checkEvaluation(cast(false, IntegerType), 0)
+    checkEvaluation(cast(true, StringType), "true")
+    checkEvaluation(cast(false, StringType), "false")
+    checkEvaluation(cast(cast(1, BooleanType), IntegerType), 1)  // round trip via boolean
+    checkEvaluation(cast(cast(0, BooleanType), IntegerType), 0)
+  }
+
+  test("cast from int 2") {  // NOTE(review): decimal checks repeat those in "cast from int"
+    checkEvaluation(cast(1, LongType), 1.toLong)
+    checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong)
+    checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong)
+
+    checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123))
+    checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
+    checkEvaluation(cast(123, DecimalType(3, 1)), null)
+    checkEvaluation(cast(123, DecimalType(2, 0)), null)
+  }
+
+  test("cast from float") {  // TODO: empty placeholder, no float casts are asserted yet
+
+  }
+
+  test("cast from double") {  // NOTE(review): the two assertions below are identical
+    checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble)
+    checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble)
+  }
+
+  test("cast from string") {  // checks Cast.nullable only; nothing is evaluated
+    assert(cast("abcdef", StringType).nullable === false)
+    assert(cast("abcdef", BinaryType).nullable === false)
+    assert(cast("abcdef", BooleanType).nullable === false)
+    assert(cast("abcdef", TimestampType).nullable === true)
+    assert(cast("abcdef", LongType).nullable === true)
+    assert(cast("abcdef", IntegerType).nullable === true)
+    assert(cast("abcdef", ShortType).nullable === true)
+    assert(cast("abcdef", ByteType).nullable === true)
+    assert(cast("abcdef", DecimalType.Unlimited).nullable === true)
+    assert(cast("abcdef", DecimalType(4, 2)).nullable === true)
+    assert(cast("abcdef", DoubleType).nullable === true)
+    assert(cast("abcdef", FloatType).nullable === true)
+  }
+
+  test("data type casting") {
+    val sd = "1970-01-01"
+    val d = Date.valueOf(sd)
+    val zts = sd + " 00:00:00"  // sd at midnight
+    val sts = sd + " 00:00:02"
+    val nts = sts + ".1"  // sts with a fractional second appended
+    val ts = Timestamp.valueOf(nts)
+
+    checkEvaluation(cast("abdef", StringType), "abdef")
+    checkEvaluation(cast("abdef", DecimalType.Unlimited), null)
+    checkEvaluation(cast("abdef", TimestampType), null)
+    checkEvaluation(cast("12.65", DecimalType.Unlimited), Decimal(12.65))
+
+    checkEvaluation(cast(cast(sd, DateType), StringType), sd)
+    checkEvaluation(cast(cast(d, StringType), DateType), 0)  // presumably days since epoch
+    checkEvaluation(cast(cast(nts, TimestampType), StringType), nts)
+    checkEvaluation(cast(cast(ts, StringType), TimestampType), ts)
+
+    // both chains end in StringType so the results can be compared as text
+    checkEvaluation(cast(cast(cast(nts, TimestampType), DateType), StringType), sd)
+    checkEvaluation(cast(cast(cast(ts, DateType), TimestampType), StringType), zts)
+
+    checkEvaluation(cast(cast("abdef", BinaryType), StringType), "abdef")
+
+    checkEvaluation(cast(cast(cast(cast(
+      cast(cast("5", ByteType), ShortType), IntegerType), FloatType), DoubleType), LongType),
+      5.toLong)
+    checkEvaluation(
+      cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType),
+        DecimalType.Unlimited), LongType), StringType), ShortType),
+      0.toShort)
+    checkEvaluation(
+      cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType),
+        DecimalType.Unlimited), LongType), StringType), ShortType),
+      null)
+    checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.Unlimited),
+      ByteType), TimestampType), LongType), StringType), ShortType),
+      0.toShort)
+
+    checkEvaluation(cast("23", DoubleType), 23d)
+    checkEvaluation(cast("23", IntegerType), 23)
+    checkEvaluation(cast("23", FloatType), 23f)
+    checkEvaluation(cast("23", DecimalType.Unlimited), Decimal(23))
+    checkEvaluation(cast("23", ByteType), 23.toByte)
+    checkEvaluation(cast("23", ShortType), 23.toShort)
+    checkEvaluation(cast("2012-12-11", DoubleType), null)  // a date string is not a number
+    checkEvaluation(cast(123, IntegerType), 123)
+
+
+    checkEvaluation(cast(Literal.create(null, IntegerType), ShortType), null)  // null stays null
+  }
+
+  test("cast and add") {  // true cast to each numeric type adds 1 to 23
+    checkEvaluation(Add(Literal(23d), cast(true, DoubleType)), 24d)
+    checkEvaluation(Add(Literal(23), cast(true, IntegerType)), 24)
+    checkEvaluation(Add(Literal(23f), cast(true, FloatType)), 24f)
+    checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.Unlimited)), Decimal(24))
+    checkEvaluation(Add(Literal(23.toByte), cast(true, ByteType)), 24.toByte)
+    checkEvaluation(Add(Literal(23.toShort), cast(true, ShortType)), 24.toShort)
+  }
+
+  test("casting to fixed-precision decimals") {
+    // Overflow and rounding for casting to fixed-precision decimals:
+    // - Values should round with HALF_UP mode by default when the scale is lowered
+    // - Values that would overflow the target precision should turn into null
+    // - Because of this, casts to fixed-precision decimals should be nullable
+
+    assert(cast(123, DecimalType.Unlimited).nullable === false)
+    assert(cast(10.03f, DecimalType.Unlimited).nullable === true)  // float/double stay nullable
+    assert(cast(10.03, DecimalType.Unlimited).nullable === true)
+    assert(cast(Decimal(10.03), DecimalType.Unlimited).nullable === false)
+
+    assert(cast(123, DecimalType(2, 1)).nullable === true)
+    assert(cast(10.03f, DecimalType(2, 1)).nullable === true)
+    assert(cast(10.03, DecimalType(2, 1)).nullable === true)
+    assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable === true)
+
+
+    checkEvaluation(cast(10.03, DecimalType.Unlimited), Decimal(10.03))
+    checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03))
+    checkEvaluation(cast(10.03, DecimalType(3, 1)), Decimal(10.0))
+    checkEvaluation(cast(10.03, DecimalType(2, 0)), Decimal(10))
+    checkEvaluation(cast(10.03, DecimalType(1, 0)), null)
+    checkEvaluation(cast(10.03, DecimalType(2, 1)), null)
+    checkEvaluation(cast(10.03, DecimalType(3, 2)), null)
+    checkEvaluation(cast(Decimal(10.03), DecimalType(3, 1)), Decimal(10.0))
+    checkEvaluation(cast(Decimal(10.03), DecimalType(3, 2)), null)
+
+    checkEvaluation(cast(10.05, DecimalType.Unlimited), Decimal(10.05))
+    checkEvaluation(cast(10.05, DecimalType(4, 2)), Decimal(10.05))
+    checkEvaluation(cast(10.05, DecimalType(3, 1)), Decimal(10.1))  // rounds up at scale 1
+    checkEvaluation(cast(10.05, DecimalType(2, 0)), Decimal(10))
+    checkEvaluation(cast(10.05, DecimalType(1, 0)), null)
+    checkEvaluation(cast(10.05, DecimalType(2, 1)), null)
+    checkEvaluation(cast(10.05, DecimalType(3, 2)), null)
+    checkEvaluation(cast(Decimal(10.05), DecimalType(3, 1)), Decimal(10.1))
+    checkEvaluation(cast(Decimal(10.05), DecimalType(3, 2)), null)
+
+    checkEvaluation(cast(9.95, DecimalType(3, 2)), Decimal(9.95))
+    checkEvaluation(cast(9.95, DecimalType(3, 1)), Decimal(10.0))
+    checkEvaluation(cast(9.95, DecimalType(2, 0)), Decimal(10))
+    checkEvaluation(cast(9.95, DecimalType(2, 1)), null)
+    checkEvaluation(cast(9.95, DecimalType(1, 0)), null)
+    checkEvaluation(cast(Decimal(9.95), DecimalType(3, 1)), Decimal(10.0))
+    checkEvaluation(cast(Decimal(9.95), DecimalType(1, 0)), null)
+
+    checkEvaluation(cast(-9.95, DecimalType(3, 2)), Decimal(-9.95))
+    checkEvaluation(cast(-9.95, DecimalType(3, 1)), Decimal(-10.0))  // mirrors the 9.95 case
+    checkEvaluation(cast(-9.95, DecimalType(2, 0)), Decimal(-10))
+    checkEvaluation(cast(-9.95, DecimalType(2, 1)), null)
+    checkEvaluation(cast(-9.95, DecimalType(1, 0)), null)
+    checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0))
+    checkEvaluation(cast(Decimal(-9.95), DecimalType(1, 0)), null)
+
+    checkEvaluation(cast(Double.NaN, DecimalType.Unlimited), null)
+    checkEvaluation(cast(1.0 / 0.0, DecimalType.Unlimited), null)  // 1.0 / 0.0 is +Infinity
+    checkEvaluation(cast(Float.NaN, DecimalType.Unlimited), null)
+    checkEvaluation(cast(1.0f / 0.0f, DecimalType.Unlimited), null)
+
+    checkEvaluation(cast(Double.NaN, DecimalType(2, 1)), null)
+    checkEvaluation(cast(1.0 / 0.0, DecimalType(2, 1)), null)
+    checkEvaluation(cast(Float.NaN, DecimalType(2, 1)), null)
+    checkEvaluation(cast(1.0f / 0.0f, DecimalType(2, 1)), null)
+  }
+
+  test("cast from date") {  // numeric targets yield null; string and timestamp targets work
+    val d = Date.valueOf("1970-01-01")
+    checkEvaluation(cast(d, ShortType), null)
+    checkEvaluation(cast(d, IntegerType), null)
+    checkEvaluation(cast(d, LongType), null)
+    checkEvaluation(cast(d, FloatType), null)
+    checkEvaluation(cast(d, DoubleType), null)
+    checkEvaluation(cast(d, DecimalType.Unlimited), null)
+    checkEvaluation(cast(d, DecimalType(10, 2)), null)
+    checkEvaluation(cast(d, StringType), "1970-01-01")
+    checkEvaluation(cast(cast(d, TimestampType), StringType), "1970-01-01 00:00:00")
+  }
+
+  test("cast from timestamp") {
+    val millis = 15 * 1000 + 2  // 15002 ms, i.e. 15.002 seconds
+    val seconds = millis * 1000 + 2  // NOTE(review): used as millis by new Timestamp(...)
+    val ts = new Timestamp(millis)
+    val tss = new Timestamp(seconds)
+    checkEvaluation(cast(ts, ShortType), 15.toShort)  // integral targets drop the fraction
+    checkEvaluation(cast(ts, IntegerType), 15)
+    checkEvaluation(cast(ts, LongType), 15.toLong)
+    checkEvaluation(cast(ts, FloatType), 15.002f)
+    checkEvaluation(cast(ts, DoubleType), 15.002)
+    checkEvaluation(cast(cast(tss, ShortType), TimestampType), ts)
+    checkEvaluation(cast(cast(tss, IntegerType), TimestampType), ts)
+    checkEvaluation(cast(cast(tss, LongType), TimestampType), ts)
+    checkEvaluation(
+      cast(cast(millis.toFloat / 1000, TimestampType), FloatType),
+      millis.toFloat / 1000)
+    checkEvaluation(
+      cast(cast(millis.toDouble / 1000, TimestampType), DoubleType),
+      millis.toDouble / 1000)
+    checkEvaluation(
+      cast(cast(Decimal(1), TimestampType), DecimalType.Unlimited),
+      Decimal(1))
+
+    // A round trip with a value finer than one millisecond (1e-8 seconds)
+    checkEvaluation(cast(cast(0.00000001, TimestampType), DoubleType), 0.00000001)
+
+    checkEvaluation(cast(Double.NaN, TimestampType), null)
+    checkEvaluation(cast(1.0 / 0.0, TimestampType), null)  // 1.0 / 0.0 is +Infinity
+    checkEvaluation(cast(Float.NaN, TimestampType), null)
+    checkEvaluation(cast(1.0f / 0.0f, TimestampType), null)
+  }
+
+  test("cast from array") {
+    val array = Literal.create(Seq("123", "abc", "", null),
+      ArrayType(StringType, containsNull = true))
+    val array_notNull = Literal.create(Seq("123", "abc", ""),
+      ArrayType(StringType, containsNull = false))
+
+    {
+      val ret = cast(array, ArrayType(IntegerType, containsNull = true))
+      assert(ret.resolved === true)
+      checkEvaluation(ret, Seq(123, null, null, null))  // non-numeric strings become null
+    }
+    {
+      val ret = cast(array, ArrayType(IntegerType, containsNull = false))
+      assert(ret.resolved === false)  // containsNull = false target is rejected
+    }
+    {
+      val ret = cast(array, ArrayType(BooleanType, containsNull = true))
+      assert(ret.resolved === true)
+      checkEvaluation(ret, Seq(true, true, false, null))  // only the empty string is false
+    }
+    {
+      val ret = cast(array, ArrayType(BooleanType, containsNull = false))
+      assert(ret.resolved === false)
+    }
+
+    {
+      val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = true))
+      assert(ret.resolved === true)
+      checkEvaluation(ret, Seq(123, null, null))
+    }
+    {
+      val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = false))
+      assert(ret.resolved === false)
+    }
+    {
+      val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true))
+      assert(ret.resolved === true)
+      checkEvaluation(ret, Seq(true, true, false))
+    }
+    {
+      val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false))
+      assert(ret.resolved === true)  // accepted: string -> boolean with a non-null source
+      checkEvaluation(ret, Seq(true, true, false))
+    }
+
+    {
+      val ret = cast(array, IntegerType)
+      assert(ret.resolved === false)  // array -> non-array target is rejected
+    }
+  }
+
+  test("cast from map") {
+    val map = Literal.create(
+      Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null),
+      MapType(StringType, StringType, valueContainsNull = true))
+    val map_notNull = Literal.create(
+      Map("a" -> "123", "b" -> "abc", "c" -> ""),
+      MapType(StringType, StringType, valueContainsNull = false))
+
+    {
+      val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = true))
+      assert(ret.resolved === true)
+      checkEvaluation(ret, Map("a" -> 123, "b" -> null, "c" -> null, "d" -> null))
+    }
+    {
+      val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = false))
+      assert(ret.resolved === false)  // valueContainsNull = false target is rejected
+    }
+    {
+      val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true))
+      assert(ret.resolved === true)
+      checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null))
+    }
+    {
+      val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false))
+      assert(ret.resolved === false)
+    }
+    {
+      val ret = cast(map, MapType(IntegerType, StringType, valueContainsNull = true))
+      assert(ret.resolved === false)  // changing the key type to IntegerType is rejected
+    }
+
+    {
+      val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = true))
+      assert(ret.resolved === true)
+      checkEvaluation(ret, Map("a" -> 123, "b" -> null, "c" -> null))
+    }
+    {
+      val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = false))
+      assert(ret.resolved === false)
+    }
+    {
+      val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true))
+      assert(ret.resolved === true)
+      checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false))
+    }
+    {
+      val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false))
+      assert(ret.resolved === true)  // accepted: string -> boolean with non-null values
+      checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false))
+    }
+    {
+      val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true))
+      assert(ret.resolved === false)
+    }
+
+    {
+      val ret = cast(map, IntegerType)
+      assert(ret.resolved === false)  // map -> non-map target is rejected
+    }
+  }
+
+  test("cast from struct") {
+    val struct = Literal.create(
+      Row("123", "abc", "", null),
+      StructType(Seq(
+        StructField("a", StringType, nullable = true),
+        StructField("b", StringType, nullable = true),
+        StructField("c", StringType, nullable = true),
+        StructField("d", StringType, nullable = true))))
+    val struct_notNull = Literal.create(
+      Row("123", "abc", ""),
+      StructType(Seq(
+        StructField("a", StringType, nullable = false),
+        StructField("b", StringType, nullable = false),
+        StructField("c", StringType, nullable = false))))
+
+    {
+      val ret = cast(struct, StructType(Seq(
+        StructField("a", IntegerType, nullable = true),
+        StructField("b", IntegerType, nullable = true),
+        StructField("c", IntegerType, nullable = true),
+        StructField("d", IntegerType, nullable = true))))
+      assert(ret.resolved === true)
+      checkEvaluation(ret, Row(123, null, null, null))
+    }
+    {
+      val ret = cast(struct, StructType(Seq(
+        StructField("a", IntegerType, nullable = true),
+        StructField("b", IntegerType, nullable = true),
+        StructField("c", IntegerType, nullable = false),
+        StructField("d", IntegerType, nullable = true))))
+      assert(ret.resolved === false)  // target field c is declared non-nullable
+    }
+    {
+      val ret = cast(struct, StructType(Seq(
+        StructField("a", BooleanType, nullable = true),
+        StructField("b", BooleanType, nullable = true),
+        StructField("c", BooleanType, nullable = true),
+        StructField("d", BooleanType, nullable = true))))
+      assert(ret.resolved === true)
+      checkEvaluation(ret, Row(true, true, false, null))
+    }
+    {
+      val ret = cast(struct, StructType(Seq(
+        StructField("a", BooleanType, nullable = true),
+        StructField("b", BooleanType, nullable = true),
+        StructField("c", BooleanType, nullable = false),
+        StructField("d", BooleanType, nullable = true))))
+      assert(ret.resolved === false)
+    }
+
+    {
+      val ret = cast(struct_notNull, StructType(Seq(
+        StructField("a", IntegerType, nullable = true),
+        StructField("b", IntegerType, nullable = true),
+        StructField("c", IntegerType, nullable = true))))
+      assert(ret.resolved === true)
+      checkEvaluation(ret, Row(123, null, null))
+    }
+    {
+      val ret = cast(struct_notNull, StructType(Seq(
+        StructField("a", IntegerType, nullable = true),
+        StructField("b", IntegerType, nullable = true),
+        StructField("c", IntegerType, nullable = false))))
+      assert(ret.resolved === false)
+    }
+    {
+      val ret = cast(struct_notNull, StructType(Seq(
+        StructField("a", BooleanType, nullable = true),
+        StructField("b", BooleanType, nullable = true),
+        StructField("c", BooleanType, nullable = true))))
+      assert(ret.resolved === true)
+      checkEvaluation(ret, Row(true, true, false))
+    }
+    {
+      val ret = cast(struct_notNull, StructType(Seq(
+        StructField("a", BooleanType, nullable = true),
+        StructField("b", BooleanType, nullable = true),
+        StructField("c", BooleanType, nullable = false))))
+      assert(ret.resolved === true)  // non-nullable c is accepted for the boolean target
+      checkEvaluation(ret, Row(true, true, false))
+    }
+
+    {
+      val ret = cast(struct, StructType(Seq(
+        StructField("a", StringType, nullable = true),
+        StructField("b", StringType, nullable = true),
+        StructField("c", StringType, nullable = true))))
+      assert(ret.resolved === false)  // target has 3 fields, source has 4
+    }
+    {
+      val ret = cast(struct, IntegerType)
+      assert(ret.resolved === false)  // struct -> non-struct target is rejected
+    }
+  }
+
+  test("complex casting") {  // casts a struct holding an array, a map and a nested struct
+    val complex = Literal.create(
+      Row(
+        Seq("123", "abc", ""),
+        Map("a" -> "123", "b" -> "abc", "c" -> ""),
+        Row(0)),
+      StructType(Seq(
+        StructField("a",
+          ArrayType(StringType, containsNull = false), nullable = true),
+        StructField("m",
+          MapType(StringType, StringType, valueContainsNull = false), nullable = true),
+        StructField("s",
+          StructType(Seq(
+            StructField("i", IntegerType, nullable = true)))))))
+
+    val ret = cast(complex, StructType(Seq(
+      StructField("a",
+        ArrayType(IntegerType, containsNull = true), nullable = true),
+      StructField("m",
+        MapType(StringType, BooleanType, valueContainsNull = false), nullable = true),
+      StructField("s",
+        StructType(Seq(
+          StructField("l", LongType, nullable = true)))))))  // source field is i: IntegerType
+
+    assert(ret.resolved === true)
+    checkEvaluation(ret, Row(
+      Seq(123, null, null),
+      Map("a" -> true, "b" -> true, "c" -> false),
+      Row(0L)))
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f74be744/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
new file mode 100644
index 0000000..481b335
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen._
+
+/**
+ * Additional tests for code generation, e.g. invoking the generators from concurrent futures.
+ */
+class CodeGenerationSuite extends SparkFunSuite {
+
+  test("multithreaded eval") {
+    import scala.concurrent._
+    import ExecutionContext.Implicits.global
+    import scala.concurrent.duration._
+
+    val futures = (1 to 20).map { _ =>  // 20 tasks submitted to the global execution context
+      future {
+        GeneratePredicate.generate(EqualTo(Literal(1), Literal(1)))
+        GenerateProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil)
+        GenerateMutableProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil)
+        GenerateOrdering.generate(Add(Literal(1), Literal(1)).asc :: Nil)
+      }
+    }
+
+    futures.foreach(Await.result(_, 10.seconds))  // throws if a task failed or took over 10s
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org