Posted to commits@spark.apache.org by rx...@apache.org on 2016/01/23 21:13:07 UTC

spark git commit: [SPARK-12904][SQL] Strength reduction for integral and decimal literal comparisons

Repository: spark
Updated Branches:
  refs/heads/master 5f5698012 -> 423783a08


[SPARK-12904][SQL] Strength reduction for integral and decimal literal comparisons

This pull request implements strength reduction for comparing integral expressions and decimal literals, which is more common now that we have switched to parsing fractional literals as decimal types (rather than doubles). I added the rules to the existing DecimalPrecision rule, with some refactoring to simplify the control flow. I also moved the DecimalPrecision rule into its own file due to its growing size.
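
As a concrete illustration of the rewrite (the table and column names here are hypothetical), a predicate such as

    SELECT * FROM t WHERE int_col > 4.7

would previously force int_col to be cast to a decimal for the comparison; with this rule the analyzer folds the literal instead, producing the equivalent

    SELECT * FROM t WHERE int_col > 4

since int_col > 4.7 holds exactly when int_col > floor(4.7) = 4 for an integral int_col.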

Author: Reynold Xin <rx...@databricks.com>

Closes #10882 from rxin/SPARK-12904-1.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/423783a0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/423783a0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/423783a0

Branch: refs/heads/master
Commit: 423783a08bb8730852973aca19603e444d15040d
Parents: 5f56980
Author: Reynold Xin <rx...@databricks.com>
Authored: Sat Jan 23 12:13:05 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sat Jan 23 12:13:05 2016 -0800

----------------------------------------------------------------------
 .../catalyst/analysis/DecimalPrecision.scala    | 259 +++++++++++++++++++
 .../catalyst/analysis/HiveTypeCoercion.scala    | 138 +---------
 .../sql/catalyst/expressions/literals.scala     |  18 ++
 .../sql/catalyst/optimizer/Optimizer.scala      |   2 +-
 .../analysis/DecimalPrecisionSuite.scala        |  97 ++++++-
 .../scala/org/apache/spark/sql/SQLContext.scala |   1 +
 6 files changed, 376 insertions(+), 139 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/423783a0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
new file mode 100644
index 0000000..ad56c98
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types._
+
+
+// scalastyle:off
+/**
+ * Calculates and propagates precision for fixed-precision decimals. Hive has a number of
+ * rules for this based on the SQL standard and MS SQL:
+ * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
+ * https://msdn.microsoft.com/en-us/library/ms190476.aspx
+ *
+ * In particular, if we have expressions e1 and e2 with precision/scale p1/s1 and p2/s2
+ * respectively, then the following operations have the following precision / scale:
+ *
+ *   Operation    Result Precision                        Result Scale
+ *   ------------------------------------------------------------------------
+ *   e1 + e2      max(s1, s2) + max(p1-s1, p2-s2) + 1     max(s1, s2)
+ *   e1 - e2      max(s1, s2) + max(p1-s1, p2-s2) + 1     max(s1, s2)
+ *   e1 * e2      p1 + p2 + 1                             s1 + s2
+ *   e1 / e2      p1 - s1 + s2 + max(6, s1 + p2 + 1)      max(6, s1 + p2 + 1)
+ *   e1 % e2      min(p1-s1, p2-s2) + max(s1, s2)         max(s1, s2)
+ *   e1 union e2  max(s1, s2) + max(p1-s1, p2-s2)         max(s1, s2)
+ *   sum(e1)      p1 + 10                                 s1
+ *   avg(e1)      p1 + 4                                  s1 + 4
+ *
+ * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited
+ * precision, do the math on unlimited-precision numbers, then introduce casts back to the
+ * required fixed precision. This allows us to do all rounding and overflow handling in the
+ * cast-to-fixed-precision operator.
+ *
+ * In addition, when mixing non-decimal types with decimals, we use the following rules:
+ * - BYTE gets turned into DECIMAL(3, 0)
+ * - SHORT gets turned into DECIMAL(5, 0)
+ * - INT gets turned into DECIMAL(10, 0)
+ * - LONG gets turned into DECIMAL(20, 0)
+ * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
+ */
+// scalastyle:on
+object DecimalPrecision extends Rule[LogicalPlan] {
+  import scala.math.{max, min}
+
+  private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
+
+  // Returns a decimal type wide enough to hold both input types
+  def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = {
+    widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale)
+  }
+  // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
+  def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
+    val scale = max(s1, s2)
+    val range = max(p1 - s1, p2 - s2)
+    DecimalType.bounded(range + scale, scale)
+  }
+
+  private def promotePrecision(e: Expression, dataType: DataType): Expression = {
+    PromotePrecision(Cast(e, dataType))
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    // fix decimal precision for expressions
+    case q => q.transformExpressions(
+      decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal))
+  }
+
+  /** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */
+  private val decimalAndDecimal: PartialFunction[Expression, Expression] = {
+    // Skip nodes whose children have not been resolved yet
+    case e if !e.childrenResolved => e
+
+    // Skip nodes that have already been promoted
+    case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e
+
+    case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+      val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
+      CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt)
+
+    case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+      val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
+      CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt)
+
+    case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+      val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2)
+      val widerType = widerDecimalType(p1, s1, p2, s2)
+      CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
+        resultType)
+
+    case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+      var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
+      var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
+      val diff = (intDig + decDig) - DecimalType.MAX_SCALE
+      if (diff > 0) {
+        decDig -= diff / 2 + 1
+        intDig = DecimalType.MAX_SCALE - decDig
+      }
+      val resultType = DecimalType.bounded(intDig + decDig, decDig)
+      val widerType = widerDecimalType(p1, s1, p2, s2)
+      CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
+        resultType)
+
+    case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+      val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+      // resultType may have lower precision, so we cast the operands to a wider type first.
+      val widerType = widerDecimalType(p1, s1, p2, s2)
+      CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
+        resultType)
+
+    case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+      val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+      // resultType may have lower precision, so we cast the operands to a wider type first.
+      val widerType = widerDecimalType(p1, s1, p2, s2)
+      CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
+        resultType)
+
+    case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
+        e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+      val resultType = widerDecimalType(p1, s1, p2, s2)
+      b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType)))
+
+    // TODO: MaxOf, MinOf, etc. might want other rules
+    // SUM and AVERAGE are handled by the implementations of those expressions
+  }
+
+  /**
+   * Strength reduction for comparing integral expressions with decimal literals.
+   * 1. int_col > decimal_literal => int_col > floor(decimal_literal)
+   * 2. int_col >= decimal_literal => int_col >= ceil(decimal_literal)
+   * 3. int_col < decimal_literal => int_col < ceil(decimal_literal)
+   * 4. int_col <= decimal_literal => int_col <= floor(decimal_literal)
+   * 5. decimal_literal > int_col => ceil(decimal_literal) > int_col
+   * 6. decimal_literal >= int_col => floor(decimal_literal) >= int_col
+   * 7. decimal_literal < int_col => floor(decimal_literal) < int_col
+   * 8. decimal_literal <= int_col => ceil(decimal_literal) <= int_col
+   *
+   * Note that technically this is an "optimization" and should go into the optimizer. However,
+   * by the time the optimizer runs, these comparison expressions would be pretty hard to pattern
+   * match because there are multiple (at least 2) levels of casts involved.
+   *
+   * There are a lot more possible rules we can implement, but we don't do them
+   * because we are not sure how common they are.
+   */
+  private val integralAndDecimalLiteral: PartialFunction[Expression, Expression] = {
+
+    case GreaterThan(i @ IntegralType(), DecimalLiteral(value)) =>
+      if (DecimalLiteral.smallerThanSmallestLong(value)) {
+        TrueLiteral
+      } else if (DecimalLiteral.largerThanLargestLong(value)) {
+        FalseLiteral
+      } else {
+        GreaterThan(i, Literal(value.floor.toLong))
+      }
+
+    case GreaterThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) =>
+      if (DecimalLiteral.smallerThanSmallestLong(value)) {
+        TrueLiteral
+      } else if (DecimalLiteral.largerThanLargestLong(value)) {
+        FalseLiteral
+      } else {
+        GreaterThanOrEqual(i, Literal(value.ceil.toLong))
+      }
+
+    case LessThan(i @ IntegralType(), DecimalLiteral(value)) =>
+      if (DecimalLiteral.smallerThanSmallestLong(value)) {
+        FalseLiteral
+      } else if (DecimalLiteral.largerThanLargestLong(value)) {
+        TrueLiteral
+      } else {
+        LessThan(i, Literal(value.ceil.toLong))
+      }
+
+    case LessThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) =>
+      if (DecimalLiteral.smallerThanSmallestLong(value)) {
+        FalseLiteral
+      } else if (DecimalLiteral.largerThanLargestLong(value)) {
+        TrueLiteral
+      } else {
+        LessThanOrEqual(i, Literal(value.floor.toLong))
+      }
+
+    case GreaterThan(DecimalLiteral(value), i @ IntegralType()) =>
+      if (DecimalLiteral.smallerThanSmallestLong(value)) {
+        FalseLiteral
+      } else if (DecimalLiteral.largerThanLargestLong(value)) {
+        TrueLiteral
+      } else {
+        GreaterThan(Literal(value.ceil.toLong), i)
+      }
+
+    case GreaterThanOrEqual(DecimalLiteral(value), i @ IntegralType()) =>
+      if (DecimalLiteral.smallerThanSmallestLong(value)) {
+        FalseLiteral
+      } else if (DecimalLiteral.largerThanLargestLong(value)) {
+        TrueLiteral
+      } else {
+        GreaterThanOrEqual(Literal(value.floor.toLong), i)
+      }
+
+    case LessThan(DecimalLiteral(value), i @ IntegralType()) =>
+      if (DecimalLiteral.smallerThanSmallestLong(value)) {
+        TrueLiteral
+      } else if (DecimalLiteral.largerThanLargestLong(value)) {
+        FalseLiteral
+      } else {
+        LessThan(Literal(value.floor.toLong), i)
+      }
+
+    case LessThanOrEqual(DecimalLiteral(value), i @ IntegralType()) =>
+      if (DecimalLiteral.smallerThanSmallestLong(value)) {
+        TrueLiteral
+      } else if (DecimalLiteral.largerThanLargestLong(value)) {
+        FalseLiteral
+      } else {
+        LessThanOrEqual(Literal(value.ceil.toLong), i)
+      }
+  }
+
+  /**
+   * Type coercion for BinaryOperator in which one side is a non-decimal numeric, and the other
+   * side is a decimal.
+   */
+  private val nondecimalAndDecimal: PartialFunction[Expression, Expression] = {
+    // Promote integers inside a binary expression with fixed-precision decimals to decimals,
+    // and fixed-precision decimals in an expression with floats / doubles to doubles
+    case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
+      (left.dataType, right.dataType) match {
+        case (t: IntegralType, DecimalType.Fixed(p, s)) =>
+          b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right))
+        case (DecimalType.Fixed(p, s), t: IntegralType) =>
+          b.makeCopy(Array(left, Cast(right, DecimalType.forType(t))))
+        case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
+          b.makeCopy(Array(left, Cast(right, DoubleType)))
+        case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
+          b.makeCopy(Array(Cast(left, DoubleType), right))
+        case _ =>
+          b
+      }
+  }
+}
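
To make the precision/scale table in the new file concrete, here is a minimal standalone Scala sketch of the addition and comparison rules (the object name is invented, and the precision cap of 38 mirrors Spark's DecimalType.MAX_PRECISION but is hard-coded here as an assumption):

    import scala.math.{max, min}

    object DecimalTypeArithmetic {
      val MaxPrecision = 38  // assumed cap, mirroring DecimalType.MAX_PRECISION

      // Stand-in for DecimalType.bounded: clamp precision and scale to the cap.
      def bounded(precision: Int, scale: Int): (Int, Int) =
        (min(precision, MaxPrecision), min(scale, MaxPrecision))

      // Comparison type: max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
      def wider(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = {
        val scale = max(s1, s2)
        val range = max(p1 - s1, p2 - s2)
        bounded(range + scale, scale)
      }

      // Addition type: max(s1, s2) + max(p1-s1, p2-s2) + 1, max(s1, s2)
      def addType(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) =
        bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))

      def main(args: Array[String]): Unit = {
        // DECIMAL(5,2) + DECIMAL(7,1): scale max(2,1)=2, integer digits
        // max(5-2, 7-1)=6, plus one carry digit => DECIMAL(9,2)
        println(addType(5, 2, 7, 1))  // (9,2)
        // Wider type for comparing the same two operands => DECIMAL(8,2)
        println(wider(5, 2, 7, 1))    // (8,2)
      }
    }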

http://git-wip-us.apache.org/repos/asf/spark/blob/423783a0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 6e43bdf..957ac89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types._
@@ -81,7 +82,7 @@ object HiveTypeCoercion {
    * Find the tightest common type of two types that might be used in a binary expression.
    * This handles all numeric types except fixed-precision decimals interacting with each other or
    * with primitive types, because in that case the precision and scale of the result depends on
-   * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]].
+   * the operation. Those rules are implemented in [[DecimalPrecision]].
    */
   val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = {
     case (t1, t2) if t1 == t2 => Some(t1)
@@ -381,141 +382,6 @@ object HiveTypeCoercion {
     }
   }
 
-  // scalastyle:off
-  /**
-   * Calculates and propagates precision for fixed-precision decimals. Hive has a number of
-   * rules for this based on the SQL standard and MS SQL:
-   * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
-   * https://msdn.microsoft.com/en-us/library/ms190476.aspx
-   *
-   * In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2
-   * respectively, then the following operations have the following precision / scale:
-   *
-   *   Operation    Result Precision                        Result Scale
-   *   ------------------------------------------------------------------------
-   *   e1 + e2      max(s1, s2) + max(p1-s1, p2-s2) + 1     max(s1, s2)
-   *   e1 - e2      max(s1, s2) + max(p1-s1, p2-s2) + 1     max(s1, s2)
-   *   e1 * e2      p1 + p2 + 1                             s1 + s2
-   *   e1 / e2      p1 - s1 + s2 + max(6, s1 + p2 + 1)      max(6, s1 + p2 + 1)
-   *   e1 % e2      min(p1-s1, p2-s2) + max(s1, s2)         max(s1, s2)
-   *   e1 union e2  max(s1, s2) + max(p1-s1, p2-s2)         max(s1, s2)
-   *   sum(e1)      p1 + 10                                 s1
-   *   avg(e1)      p1 + 4                                  s1 + 4
-   *
-   * Catalyst also has unlimited-precision decimals. For those, all ops return unlimited precision.
-   *
-   * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited
-   * precision, do the math on unlimited-precision numbers, then introduce casts back to the
-   * required fixed precision. This allows us to do all rounding and overflow handling in the
-   * cast-to-fixed-precision operator.
-   *
-   * In addition, when mixing non-decimal types with decimals, we use the following rules:
-   * - BYTE gets turned into DECIMAL(3, 0)
-   * - SHORT gets turned into DECIMAL(5, 0)
-   * - INT gets turned into DECIMAL(10, 0)
-   * - LONG gets turned into DECIMAL(20, 0)
-   * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
-   */
-  // scalastyle:on
-  object DecimalPrecision extends Rule[LogicalPlan] {
-    import scala.math.{max, min}
-
-    private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
-
-    // Returns the wider decimal type that's wider than both of them
-    def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = {
-      widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale)
-    }
-    // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
-    def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
-      val scale = max(s1, s2)
-      val range = max(p1 - s1, p2 - s2)
-      DecimalType.bounded(range + scale, scale)
-    }
-
-    private def promotePrecision(e: Expression, dataType: DataType): Expression = {
-      PromotePrecision(Cast(e, dataType))
-    }
-
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-
-      // fix decimal precision for expressions
-      case q => q.transformExpressions {
-        // Skip nodes whose children have not been resolved yet
-        case e if !e.childrenResolved => e
-
-        // Skip nodes who is already promoted
-        case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e
-
-        case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-          val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
-          CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt)
-
-        case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-          val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
-          CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt)
-
-        case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-          val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2)
-          val widerType = widerDecimalType(p1, s1, p2, s2)
-          CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
-            resultType)
-
-        case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-          var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
-          var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
-          val diff = (intDig + decDig) - DecimalType.MAX_SCALE
-          if (diff > 0) {
-            decDig -= diff / 2 + 1
-            intDig = DecimalType.MAX_SCALE - decDig
-          }
-          val resultType = DecimalType.bounded(intDig + decDig, decDig)
-          val widerType = widerDecimalType(p1, s1, p2, s2)
-          CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
-            resultType)
-
-        case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-          val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
-          // resultType may have lower precision, so we cast them into wider type first.
-          val widerType = widerDecimalType(p1, s1, p2, s2)
-          CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
-            resultType)
-
-        case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-          val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
-          // resultType may have lower precision, so we cast them into wider type first.
-          val widerType = widerDecimalType(p1, s1, p2, s2)
-          CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
-            resultType)
-
-        case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
-                                  e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
-          val resultType = widerDecimalType(p1, s1, p2, s2)
-          b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType)))
-
-        // Promote integers inside a binary expression with fixed-precision decimals to decimals,
-        // and fixed-precision decimals in an expression with floats / doubles to doubles
-        case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
-          (left.dataType, right.dataType) match {
-            case (t: IntegralType, DecimalType.Fixed(p, s)) =>
-              b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right))
-            case (DecimalType.Fixed(p, s), t: IntegralType) =>
-              b.makeCopy(Array(left, Cast(right, DecimalType.forType(t))))
-            case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
-              b.makeCopy(Array(left, Cast(right, DoubleType)))
-            case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
-              b.makeCopy(Array(Cast(left, DoubleType), right))
-            case _ =>
-              b
-          }
-
-        // TODO: MaxOf, MinOf, etc might want other rules
-
-        // SUM and AVERAGE are handled by the implementations of those expressions
-      }
-    }
-  }
-
   /**
    * Changes numeric values to booleans so that expressions like true = 1 can be evaluated.
    */
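
One structural point worth noting about the refactor: the new apply() composes the three rewrite sets as Scala PartialFunctions chained with orElse, so for each expression the first set whose pattern matches wins and the rest are never consulted. A standalone sketch of that dispatch idiom with toy rules (none of these names come from Spark):

    object PartialFunctionChaining {
      val evens: PartialFunction[Int, String] = {
        case n if n % 2 == 0 => s"$n is even"
      }
      val negatives: PartialFunction[Int, String] = {
        case n if n < 0 => s"$n is negative"
      }
      // Tries `evens` first; `negatives` only sees inputs that evens rejected.
      val combined: PartialFunction[Int, String] = evens.orElse(negatives)

      def main(args: Array[String]): Unit = {
        println(combined(4))              // 4 is even
        println(combined(-3))             // -3 is negative
        println(combined.isDefinedAt(3))  // false: neither rule applies
      }
    }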

http://git-wip-us.apache.org/repos/asf/spark/blob/423783a0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index db30845..ca0892e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -140,6 +140,24 @@ object IntegerLiteral {
 }
 
 /**
+ * Extractor and other utility methods for decimal literals.
+ */
+object DecimalLiteral {
+  def apply(v: Long): Literal = Literal(Decimal(v))
+
+  def apply(v: Double): Literal = Literal(Decimal(v))
+
+  def unapply(e: Expression): Option[Decimal] = e match {
+    case Literal(v, _: DecimalType) => Some(v.asInstanceOf[Decimal])
+    case _ => None
+  }
+
+  def largerThanLargestLong(v: Decimal): Boolean = v > Decimal(Long.MaxValue)
+
+  def smallerThanSmallestLong(v: Decimal): Boolean = v < Decimal(Long.MinValue)
+}
+
+/**
  * In order to do type checking, use Literal.create() instead of constructor
  */
 case class Literal protected (value: Any, dataType: DataType)
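
For readers unfamiliar with the extractor-object idiom used by DecimalLiteral above, here is a self-contained sketch of the same pattern, with scala.math.BigDecimal standing in for Spark's Decimal and Literal types (Expr, Lit, and DecimalLit are local stand-ins, not Spark classes):

    object DecimalLiteralSketch {
      sealed trait Expr
      case class Lit(value: Any) extends Expr

      // apply() builds a literal; unapply() recovers the decimal value
      // during pattern matching, just like the DecimalLiteral object above.
      object DecimalLit {
        def apply(v: Double): Expr = Lit(BigDecimal(v))
        def unapply(e: Expr): Option[BigDecimal] = e match {
          case Lit(v: BigDecimal) => Some(v)
          case _ => None
        }
        def largerThanLargestLong(v: BigDecimal): Boolean =
          v > BigDecimal(Long.MaxValue)
        def smallerThanSmallestLong(v: BigDecimal): Boolean =
          v < BigDecimal(Long.MinValue)
      }

      def main(args: Array[String]): Unit = {
        DecimalLit(4.7) match {
          case DecimalLit(v) => println(s"matched decimal literal $v")
          case _             => println("no match")
        }
        println(DecimalLit.largerThanLargestLong(BigDecimal(Long.MaxValue) + 1))  // true
      }
    }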

http://git-wip-us.apache.org/repos/asf/spark/blob/423783a0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 44455b4..6addc20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1006,7 +1006,7 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
  * Speeds up aggregates on fixed-precision decimals by executing them on unscaled Long values.
  *
  * This uses the same rules for increasing the precision and scale of the output as
- * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.DecimalPrecision]].
+ * [[org.apache.spark.sql.catalyst.analysis.DecimalPrecision]].
  */
 object DecimalAggregates extends Rule[LogicalPlan] {
   import Decimal.MAX_LONG_DIGITS
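
The doc comment above refers to DecimalAggregates' technique of running aggregates on unscaled Long values: a DECIMAL(p, s) with few enough digits can be represented as value * 10^s in a Long, summed cheaply, and rescaled once at the end. A rough standalone illustration of the idea, assuming all inputs share the same scale (this is not the rule's actual code):

    object UnscaledSumSketch {
      def main(args: Array[String]): Unit = {
        val scale = 2
        // 4.70, 1.25, 3.05 at scale 2 have unscaled values 470, 125, 305.
        val values = Seq(BigDecimal("4.70"), BigDecimal("1.25"), BigDecimal("3.05"))

        // Sum the unscaled longs instead of full decimal values...
        val unscaledSum = values.map(_.underlying.unscaledValue.longValueExact).sum

        // ...then reattach the scale once at the end.
        println(BigDecimal(BigInt(unscaledSum), scale))  // 9.00
      }
    }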

http://git-wip-us.apache.org/repos/asf/spark/blob/423783a0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 24c608e..b2613e4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -19,14 +19,17 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.scalatest.BeforeAndAfter
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier}
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
+import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Union}
 import org.apache.spark.sql.types._
 
-class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
+
+class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter {
   val conf = new SimpleCatalystConf(true)
   val catalog = new SimpleCatalog(conf)
   val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf)
@@ -181,4 +184,94 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
     assert(d4.isWiderThan(FloatType) === false)
     assert(d4.isWiderThan(DoubleType) === false)
   }
+
+  test("strength reduction for integer/decimal comparisons - basic test") {
+    Seq(ByteType, ShortType, IntegerType, LongType).foreach { dt =>
+      val int = AttributeReference("a", dt)()
+
+      ruleTest(int > Literal(Decimal(4)), int > Literal(4L))
+      ruleTest(int > Literal(Decimal(4.7)), int > Literal(4L))
+
+      ruleTest(int >= Literal(Decimal(4)), int >= Literal(4L))
+      ruleTest(int >= Literal(Decimal(4.7)), int >= Literal(5L))
+
+      ruleTest(int < Literal(Decimal(4)), int < Literal(4L))
+      ruleTest(int < Literal(Decimal(4.7)), int < Literal(5L))
+
+      ruleTest(int <= Literal(Decimal(4)), int <= Literal(4L))
+      ruleTest(int <= Literal(Decimal(4.7)), int <= Literal(4L))
+
+      ruleTest(Literal(Decimal(4)) > int, Literal(4L) > int)
+      ruleTest(Literal(Decimal(4.7)) > int, Literal(5L) > int)
+
+      ruleTest(Literal(Decimal(4)) >= int, Literal(4L) >= int)
+      ruleTest(Literal(Decimal(4.7)) >= int, Literal(4L) >= int)
+
+      ruleTest(Literal(Decimal(4)) < int, Literal(4L) < int)
+      ruleTest(Literal(Decimal(4.7)) < int, Literal(4L) < int)
+
+      ruleTest(Literal(Decimal(4)) <= int, Literal(4L) <= int)
+      ruleTest(Literal(Decimal(4.7)) <= int, Literal(5L) <= int)
+
+    }
+  }
+
+  test("strength reduction for integer/decimal comparisons - overflow test") {
+    val maxValue = Literal(Decimal(Long.MaxValue))
+    val overflow = Literal(Decimal(Long.MaxValue) + Decimal(0.1))
+    val minValue = Literal(Decimal(Long.MinValue))
+    val underflow = Literal(Decimal(Long.MinValue) - Decimal(0.1))
+
+    Seq(ByteType, ShortType, IntegerType, LongType).foreach { dt =>
+      val int = AttributeReference("a", dt)()
+
+      ruleTest(int > maxValue, int > Literal(Long.MaxValue))
+      ruleTest(int > overflow, FalseLiteral)
+      ruleTest(int > minValue, int > Literal(Long.MinValue))
+      ruleTest(int > underflow, TrueLiteral)
+
+      ruleTest(int >= maxValue, int >= Literal(Long.MaxValue))
+      ruleTest(int >= overflow, FalseLiteral)
+      ruleTest(int >= minValue, int >= Literal(Long.MinValue))
+      ruleTest(int >= underflow, TrueLiteral)
+
+      ruleTest(int < maxValue, int < Literal(Long.MaxValue))
+      ruleTest(int < overflow, TrueLiteral)
+      ruleTest(int < minValue, int < Literal(Long.MinValue))
+      ruleTest(int < underflow, FalseLiteral)
+
+      ruleTest(int <= maxValue, int <= Literal(Long.MaxValue))
+      ruleTest(int <= overflow, TrueLiteral)
+      ruleTest(int <= minValue, int <= Literal(Long.MinValue))
+      ruleTest(int <= underflow, FalseLiteral)
+
+      ruleTest(maxValue > int, Literal(Long.MaxValue) > int)
+      ruleTest(overflow > int, TrueLiteral)
+      ruleTest(minValue > int, Literal(Long.MinValue) > int)
+      ruleTest(underflow > int, FalseLiteral)
+
+      ruleTest(maxValue >= int, Literal(Long.MaxValue) >= int)
+      ruleTest(overflow >= int, TrueLiteral)
+      ruleTest(minValue >= int, Literal(Long.MinValue) >= int)
+      ruleTest(underflow >= int, FalseLiteral)
+
+      ruleTest(maxValue < int, Literal(Long.MaxValue) < int)
+      ruleTest(overflow < int, FalseLiteral)
+      ruleTest(minValue < int, Literal(Long.MinValue) < int)
+      ruleTest(underflow < int, TrueLiteral)
+
+      ruleTest(maxValue <= int, Literal(Long.MaxValue) <= int)
+      ruleTest(overflow <= int, FalseLiteral)
+      ruleTest(minValue <= int, Literal(Long.MinValue) <= int)
+      ruleTest(underflow <= int, TrueLiteral)
+    }
+  }
+
+  /** Checks that the DecimalPrecision rule transforms `initial` into `transformed`. */
+  def ruleTest(initial: Expression, transformed: Expression): Unit = {
+    val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
+    comparePlans(
+      DecimalPrecision(Project(Seq(Alias(initial, "a")()), testRelation)),
+      Project(Seq(Alias(transformed, "a")()), testRelation))
+  }
 }
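
The overflow tests above pin down the saturation behavior: when the decimal literal falls outside the Long range, the comparison folds to a constant rather than being rewritten. A compact standalone model of that decision for the int_col > decimal_literal case, with BigDecimal standing in for Spark's Decimal (FoldResult and its cases are invented names):

    object SaturationSketch {
      sealed trait FoldResult
      case object AlwaysTrue extends FoldResult
      case object AlwaysFalse extends FoldResult
      case class GreaterThanLong(bound: Long) extends FoldResult

      // Models rule 1: int_col > decimal_literal.
      def foldGreaterThan(lit: BigDecimal): FoldResult =
        if (lit < BigDecimal(Long.MinValue)) AlwaysTrue        // every integral value exceeds lit
        else if (lit > BigDecimal(Long.MaxValue)) AlwaysFalse  // no integral value exceeds lit
        else GreaterThanLong(lit.setScale(0, BigDecimal.RoundingMode.FLOOR).toLong)

      def main(args: Array[String]): Unit = {
        println(foldGreaterThan(BigDecimal("4.7")))              // GreaterThanLong(4)
        println(foldGreaterThan(BigDecimal(Long.MaxValue) + 1))  // AlwaysFalse
        println(foldGreaterThan(BigDecimal(Long.MinValue) - 1))  // AlwaysTrue
      }
    }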

http://git-wip-us.apache.org/repos/asf/spark/blob/423783a0/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 147e355..b774da3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -74,6 +74,7 @@ class SQLContext private[sql](
   def this(sparkContext: SparkContext) = {
     this(sparkContext, new CacheManager, SQLContext.createListenerAndUI(sparkContext), true)
   }
+
   def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)
 
   // If spark.sql.allowMultipleContexts is true, we will throw an exception if a user

