You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/06/11 21:57:40 UTC

spark git commit: [SPARK-8305] [SPARK-8190] [SQL] improve codegen

Repository: spark
Updated Branches:
  refs/heads/master 424b0075a -> 1191c3efc


[SPARK-8305] [SPARK-8190] [SQL] improve codegen

This PR fixes a few small issues in codegen:

1. cast decimal to boolean
2. do not inline literal with null
3. improve SpecificRow.equals()
4. test expressions with optimized expressions
5. fix compare with BinaryType

cc rxin chenghao-intel

Author: Davies Liu <da...@databricks.com>

Closes #6755 from davies/fix_codegen and squashes the following commits:

ef27343 [Davies Liu] address comments
6617ea6 [Davies Liu] fix scala tyle
70b7dda [Davies Liu] improve codegen


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1191c3ef
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1191c3ef
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1191c3ef

Branch: refs/heads/master
Commit: 1191c3efc605d9c6d1df4b38ddae8d210a361b5b
Parents: 424b007
Author: Davies Liu <da...@databricks.com>
Authored: Thu Jun 11 12:57:33 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Jun 11 12:57:33 2015 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/BaseRow.java     | 21 +++++++++
 .../spark/sql/catalyst/expressions/Cast.scala   |  4 +-
 .../expressions/codegen/CodeGenerator.scala     | 34 +++++++++++----
 .../codegen/GenerateMutableProjection.scala     |  1 -
 .../expressions/codegen/GenerateOrdering.scala  | 39 ++---------------
 .../codegen/GenerateProjection.scala            | 45 ++++++++++----------
 .../sql/catalyst/expressions/conditionals.scala |  2 +-
 .../sql/catalyst/expressions/literals.scala     |  3 +-
 .../sql/catalyst/expressions/predicates.scala   | 20 ++++-----
 .../spark/sql/catalyst/util/TypeUtils.scala     |  8 ++++
 .../org/apache/spark/sql/types/BinaryType.scala |  7 +--
 .../sql/catalyst/expressions/CastSuite.scala    | 37 ++++++++++++++--
 .../expressions/ExpressionEvalHelper.scala      | 12 ++++++
 .../optimizer/ExpressionOptimizationSuite.scala | 37 ----------------
 14 files changed, 141 insertions(+), 129 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1191c3ef/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java
index 6584882..e91daf1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java
@@ -154,6 +154,27 @@ public abstract class BaseRow implements Row {
     throw new UnsupportedOperationException();
   }
 
+  /**
+   * A generic version of Row.equals(Row), which is used for tests.
+   */
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof Row) {
+      Row row = (Row) other;
+      int n = size();
+      if (n != row.size()) {
+        return false;
+      }
+      for (int i = 0; i < n; i ++) {
+        if (isNullAt(i) != row.isNullAt(i) || (!isNullAt(i) && !get(i).equals(row.get(i)))) {
+          return false;
+        }
+      }
+      return true;
+    }
+    return false;
+  }
+
   @Override
   public Row copy() {
     final int n = size();

http://git-wip-us.apache.org/repos/asf/spark/blob/1191c3ef/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8d93957..037efd7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -141,7 +141,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
     case ByteType =>
       buildCast[Byte](_, _ != 0)
     case DecimalType() =>
-      buildCast[Decimal](_, _ != 0)
+      buildCast[Decimal](_, _ != Decimal(0))
     case DoubleType =>
       buildCast[Double](_, _ != 0)
     case FloatType =>
@@ -454,7 +454,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       case (BooleanType, dt: NumericType) =>
         defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
       case (dt: DecimalType, BooleanType) =>
-        defineCodeGen(ctx, ev, c => s"$c.isZero()")
+        defineCodeGen(ctx, ev, c => s"!$c.isZero()")
       case (dt: NumericType, BooleanType) =>
         defineCodeGen(ctx, ev, c => s"$c != 0")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1191c3ef/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 80aa8fa..ecf8e0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -161,15 +161,23 @@ class CodeGenContext {
   }
 
   /**
-   * Returns a function to generate equal expression in Java
+   * Generate code for equal expression in Java
    */
-  def equalFunc(dataType: DataType): ((String, String) => String) = dataType match {
-    case BinaryType => { case (eval1, eval2) =>
-      s"java.util.Arrays.equals($eval1, $eval2)" }
-    case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType =>
-      { case (eval1, eval2) => s"$eval1 == $eval2" }
-    case other =>
-      { case (eval1, eval2) => s"$eval1.equals($eval2)" }
+  def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
+    case BinaryType => s"java.util.Arrays.equals($c1, $c2)"
+    case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
+    case other => s"$c1.equals($c2)"
+  }
+
+  /**
+   * Generate code for compare expression in Java
+   */
+  def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
+    // Use signum() to keep any small difference between float/double
+    case FloatType | DoubleType => s"(int)java.lang.Math.signum($c1 - $c2)"
+    case dt: DataType if isPrimitiveType(dt) => s"(int)($c1 - $c2)"
+    case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
+    case other => s"$c1.compare($c2)"
   }
 
   /**
@@ -182,6 +190,16 @@ class CodeGenContext {
    * Returns true if the data type has a special accessor and setter in [[Row]].
    */
   def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt)
+
+  /**
+   * List of data types whose Java type is a primitive type
+   */
+  val primitiveTypes = nativeTypes ++ Seq(DateType, TimestampType)
+
+  /**
+   * Returns true if the Java type is primitive type
+   */
+  def isPrimitiveType(dt: DataType): Boolean = primitiveTypes.contains(dt)
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/1191c3ef/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index e5ee2ac..ed3df54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -82,7 +82,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
       }
     """
 
-
     logDebug(s"code for ${expressions.mkString(",")}:\n$code")
 
     val c = compile(code)

http://git-wip-us.apache.org/repos/asf/spark/blob/1191c3ef/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index 36e155d..56ecc5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -21,7 +21,6 @@ import org.apache.spark.Logging
 import org.apache.spark.annotation.Private
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{BinaryType, NumericType}
 
 /**
  * Inherits some default implementation for Java from `Ordering[Row]`
@@ -55,39 +54,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
       val evalA = order.child.gen(ctx)
       val evalB = order.child.gen(ctx)
       val asc = order.direction == Ascending
-      val compare = order.child.dataType match {
-        case BinaryType =>
-          s"""
-            {
-              byte[] x = ${if (asc) evalA.primitive else evalB.primitive};
-              byte[] y = ${if (!asc) evalB.primitive else evalA.primitive};
-              int j = 0;
-              while (j < x.length && j < y.length) {
-                if (x[j] != y[j]) return x[j] - y[j];
-                j = j + 1;
-              }
-              int d = x.length - y.length;
-              if (d != 0) {
-                return d;
-              }
-            }"""
-        case _: NumericType =>
-          s"""
-            if (${evalA.primitive} != ${evalB.primitive}) {
-              if (${evalA.primitive} > ${evalB.primitive}) {
-                return ${if (asc) "1" else "-1"};
-              } else {
-                return ${if (asc) "-1" else "1"};
-              }
-            }"""
-        case _ =>
-          s"""
-            int comp = ${evalA.primitive}.compare(${evalB.primitive});
-            if (comp != 0) {
-              return ${if (asc) "comp" else "-comp"};
-            }"""
-      }
-
       s"""
           i = $a;
           ${evalA.code}
@@ -100,7 +66,10 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
           } else if (${evalB.isNull}) {
             return ${if (order.direction == Ascending) "1" else "-1"};
           } else {
-            $compare
+            int comp = ${ctx.genComp(order.child.dataType, evalA.primitive, evalB.primitive)};
+            if (comp != 0) {
+              return ${if (asc) "comp" else "-comp"};
+            }
           }
       """
     }.mkString("\n")

http://git-wip-us.apache.org/repos/asf/spark/blob/1191c3ef/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 274429c..9b906c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -72,14 +72,12 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
     }.mkString("\n        ")
 
     val specificAccessorFunctions = ctx.nativeTypes.map { dataType =>
-      val cases = expressions.zipWithIndex.map {
-        case (e, i) if e.dataType == dataType
-          || dataType == IntegerType && e.dataType == DateType
-          || dataType == LongType && e.dataType == TimestampType =>
-          s"case $i: return c$i;"
-        case _ => ""
+      val cases = expressions.zipWithIndex.flatMap {
+        case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) =>
+          List(s"case $i: return c$i;")
+        case _ => Nil
       }.mkString("\n        ")
-      if (cases.count(_ != '\n') > 0) {
+      if (cases.length > 0) {
         s"""
       @Override
       public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) {
@@ -89,7 +87,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
         switch (i) {
         $cases
         }
-        return ${ctx.defaultValue(dataType)};
+        throw new IllegalArgumentException("Invalid index: " + i
+          + " in ${ctx.accessorForType(dataType)}");
       }"""
       } else {
         ""
@@ -97,14 +96,12 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
     }.mkString("\n")
 
     val specificMutatorFunctions = ctx.nativeTypes.map { dataType =>
-      val cases = expressions.zipWithIndex.map {
-        case (e, i) if e.dataType == dataType
-          || dataType == IntegerType && e.dataType == DateType
-          || dataType == LongType && e.dataType == TimestampType =>
-          s"case $i: { c$i = value; return; }"
-        case _ => ""
-      }.mkString("\n")
-      if (cases.count(_ != '\n') > 0) {
+      val cases = expressions.zipWithIndex.flatMap {
+        case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) =>
+          List(s"case $i: { c$i = value; return; }")
+        case _ => Nil
+      }.mkString("\n        ")
+      if (cases.length > 0) {
         s"""
       @Override
       public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) {
@@ -112,6 +109,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
         switch (i) {
         $cases
         }
+        throw new IllegalArgumentException("Invalid index: " + i +
+          " in ${ctx.mutatorForType(dataType)}");
       }"""
       } else {
         ""
@@ -139,9 +138,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
 
     val columnChecks = expressions.zipWithIndex.map { case (e, i) =>
       s"""
-          if (isNullAt($i) != row.isNullAt($i) || !isNullAt($i) && !get($i).equals(row.get($i))) {
-            return false;
-          }
+        if (nullBits[$i] != row.nullBits[$i] ||
+          (!nullBits[$i] && !(${ctx.genEqual(e.dataType, s"c$i", s"row.c$i")}))) {
+          return false;
+        }
       """
     }.mkString("\n")
 
@@ -174,7 +174,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
       }
 
       public int size() { return ${expressions.length};}
-      private boolean[] nullBits = new boolean[${expressions.length}];
+      protected boolean[] nullBits = new boolean[${expressions.length}];
       public void setNullAt(int i) { nullBits[i] = true; }
       public boolean isNullAt(int i) { return nullBits[i]; }
 
@@ -207,9 +207,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
 
       @Override
       public boolean equals(Object other) {
-        if (other instanceof Row) {
-          Row row = (Row) other;
-          if (row.length() != size()) return false;
+        if (other instanceof SpecificRow) {
+          SpecificRow row = (SpecificRow) other;
           $columnChecks
           return true;
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/1191c3ef/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
index 1a5cde2..72b9f23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
@@ -261,7 +261,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
           ${cond.code}
           if (${keyEval.isNull} && ${cond.isNull} ||
             !${keyEval.isNull} && !${cond.isNull}
-             && ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) {
+             && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
             $got = true;
             ${res.code}
             ${ev.isNull} = ${res.isNull};

http://git-wip-us.apache.org/repos/asf/spark/blob/1191c3ef/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 833c08a..ef50c50 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -92,8 +92,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
     // change the isNull and primitive to consts, to inline them
     if (value == null) {
       ev.isNull = "true"
-      ev.primitive = ctx.defaultValue(dataType)
-      ""
+      s"final ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};"
     } else {
       dataType match {
         case BooleanType =>

http://git-wip-us.apache.org/repos/asf/spark/blob/1191c3ef/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 2c49352..7574d1c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -250,16 +250,11 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    left.dataType match {
-      case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, {
-        (c1, c3) => s"$c1 $symbol $c3"
-      })
-      case DateType | TimestampType => defineCodeGen (ctx, ev, {
-        (c1, c3) => s"$c1 $symbol $c3"
-      })
-      case other => defineCodeGen (ctx, ev, {
-        (c1, c2) => s"$c1.compare($c2) $symbol 0"
-      })
+    if (ctx.isPrimitiveType(left.dataType)) {
+      // faster version
+      defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2")
+    } else {
+      defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0")
     }
   }
 
@@ -280,8 +275,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
     if (left.dataType != BinaryType) l == r
     else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
   }
+
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType))
+    defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2))
   }
 }
 
@@ -307,7 +303,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val eval1 = left.gen(ctx)
     val eval2 = right.gen(ctx)
-    val equalCode = ctx.equalFunc(left.dataType)(eval1.primitive, eval2.primitive)
+    val equalCode = ctx.genEqual(left.dataType, eval1.primitive, eval2.primitive)
     ev.isNull = "false"
     eval1.code + eval2.code + s"""
         boolean ${ev.primitive} = (${eval1.isNull} && ${eval2.isNull}) ||

http://git-wip-us.apache.org/repos/asf/spark/blob/1191c3ef/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 0bb12d2..04857a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -53,4 +53,12 @@ object TypeUtils {
 
   def getOrdering(t: DataType): Ordering[Any] =
     t.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]]
+
+  def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
+    for (i <- 0 until x.length; if i < y.length) {
+      val res = x(i).compareTo(y(i))
+      if (res != 0) return res
+    }
+    x.length - y.length
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1191c3ef/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
index a581a9e..9b58601 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.typeTag
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.sql.catalyst.util.TypeUtils
 
 
 /**
@@ -43,11 +44,7 @@ class BinaryType private() extends AtomicType {
 
   private[sql] val ordering = new Ordering[InternalType] {
     def compare(x: Array[Byte], y: Array[Byte]): Int = {
-      for (i <- 0 until x.length; if i < y.length) {
-        val res = x(i).compareTo(y(i))
-        if (res != 0) return res
-      }
-      x.length - y.length
+      TypeUtils.compareBinary(x, y)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1191c3ef/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 3aca94d..969c6cc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -43,7 +43,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("cast from int") {
     checkCast(0, false)
     checkCast(1, true)
-    checkCast(5, true)
+    checkCast(-5, true)
     checkCast(1, 1.toByte)
     checkCast(1, 1.toShort)
     checkCast(1, 1)
@@ -61,7 +61,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("cast from long") {
     checkCast(0L, false)
     checkCast(1L, true)
-    checkCast(5L, true)
+    checkCast(-5L, true)
     checkCast(1L, 1.toByte)
     checkCast(1L, 1.toShort)
     checkCast(1L, 1)
@@ -99,10 +99,28 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("cast from float") {
-
+    checkCast(0.0f, false)
+    checkCast(0.5f, true)
+    checkCast(-5.0f, true)
+    checkCast(1.5f, 1.toByte)
+    checkCast(1.5f, 1.toShort)
+    checkCast(1.5f, 1)
+    checkCast(1.5f, 1.toLong)
+    checkCast(1.5f, 1.5)
+    checkCast(1.5f, "1.5")
   }
 
   test("cast from double") {
+    checkCast(0.0, false)
+    checkCast(0.5, true)
+    checkCast(-5.0, true)
+    checkCast(1.5, 1.toByte)
+    checkCast(1.5, 1.toShort)
+    checkCast(1.5, 1)
+    checkCast(1.5, 1.toLong)
+    checkCast(1.5, 1.5f)
+    checkCast(1.5, "1.5")
+
     checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble)
     checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble)
   }
@@ -183,6 +201,19 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Add(Literal(23.toShort), cast(true, ShortType)), 24.toShort)
   }
 
+  test("from decimal") {
+    checkCast(Decimal(0.0), false)
+    checkCast(Decimal(0.5), true)
+    checkCast(Decimal(-5.0), true)
+    checkCast(Decimal(1.5), 1.toByte)
+    checkCast(Decimal(1.5), 1.toShort)
+    checkCast(Decimal(1.5), 1)
+    checkCast(Decimal(1.5), 1.toLong)
+    checkCast(Decimal(1.5), 1.5f)
+    checkCast(Decimal(1.5), 1.5)
+    checkCast(Decimal(1.5), "1.5")
+  }
+
   test("casting to fixed-precision decimals") {
     // Overflow and rounding for casting to fixed-precision decimals:
     // - Values should round with HALF_UP mode by default when you lower scale

http://git-wip-us.apache.org/repos/asf/spark/blob/1191c3ef/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 87a92b8..4a241d3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -23,6 +23,8 @@ import org.scalatest.Matchers._
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection}
+import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer
+import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
 
 /**
  * A few helper functions for expression evaluation testing. Mixin this trait to use them.
@@ -39,6 +41,7 @@ trait ExpressionEvalHelper {
     checkEvaluationWithoutCodegen(expression, expected, inputRow)
     checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow)
     checkEvaluationWithGeneratedProjection(expression, expected, inputRow)
+    checkEvaluationWithOptimization(expression, expected, inputRow)
   }
 
   protected def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = {
@@ -122,6 +125,15 @@ trait ExpressionEvalHelper {
     }
   }
 
+  protected def checkEvaluationWithOptimization(
+      expression: Expression,
+      expected: Any,
+      inputRow: Row = EmptyRow): Unit = {
+    val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
+    val optimizedPlan = DefaultOptimizer.execute(plan)
+    checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow)
+  }
+
   protected def checkDoubleEvaluation(
       expression: Expression,
       expected: Spread[Double],

http://git-wip-us.apache.org/repos/asf/spark/blob/1191c3ef/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala
deleted file mode 100644
index f33a18d..0000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.optimizer
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical._
-
-/**
- * Overrides our expression evaluation tests and reruns them after optimization has occured.  This
- * is to ensure that constant folding and other optimizations do not break anything.
- */
-class ExpressionOptimizationSuite extends SparkFunSuite with ExpressionEvalHelper {
-  override def checkEvaluation(
-      expression: Expression,
-      expected: Any,
-      inputRow: Row = EmptyRow): Unit = {
-    val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
-    val optimizedPlan = DefaultOptimizer.execute(plan)
-    super.checkEvaluation(optimizedPlan.expressions.head, expected, inputRow)
-  }
-}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org