Posted to commits@spark.apache.org by ma...@apache.org on 2014/11/02 03:29:55 UTC

[3/3] git commit: [SPARK-3930] [SPARK-3933] Support fixed-precision decimal in SQL, and some optimizations

[SPARK-3930] [SPARK-3933] Support fixed-precision decimal in SQL, and some optimizations

- Adds optional precision and scale to Spark SQL's decimal type, which behave similarly to those in Hive 13 (https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf)
- Replaces our internal representation of decimals with a Decimal class that can store small values in a mutable Long, saving memory in this situation and letting some operations happen directly on Longs

This is still marked WIP because there are a few TODOs, but I'll remove that tag when done.
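
A minimal sketch of the two changes above (values are illustrative; the class and factory names are those introduced in the diff below):

    import org.apache.spark.sql.catalyst.types.DecimalType
    import org.apache.spark.sql.catalyst.types.decimal.Decimal

    // Optional precision and scale on the SQL type, as in Hive's DECIMAL(p,s)
    val fixed     = DecimalType(10, 2)     // precision 10, scale 2
    val unlimited = DecimalType.Unlimited  // no precision/scale bound

    // Small values are held in a mutable Long rather than a BigDecimal
    val d = new Decimal().set(1234L, 4, 2) // represents 12.34
    assert(d.toUnscaledLong == 1234L)
    assert(d.toBigDecimal == BigDecimal("12.34"))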

Author: Matei Zaharia <ma...@databricks.com>

Closes #2983 from mateiz/decimal-1 and squashes the following commits:

35e6b02 [Matei Zaharia] Fix issues after merge
227f24a [Matei Zaharia] Review comments
31f915e [Matei Zaharia] Implement Davies's suggestions in Python
eb84820 [Matei Zaharia] Support reading/writing decimals as fixed-length binary in Parquet
4dc6bae [Matei Zaharia] Fix decimal support in PySpark
d1d9d68 [Matei Zaharia] Fix compile error and test issues after rebase
b28933d [Matei Zaharia] Support decimal precision/scale in Hive metastore
2118c0d [Matei Zaharia] Some test and bug fixes
81db9cb [Matei Zaharia] Added mutable Decimal that will be more efficient for small precisions
7af0c3b [Matei Zaharia] Add optional precision and scale to DecimalType, but use Unlimited for now
ec0a947 [Matei Zaharia] Make the result of AVG on Decimals be Decimal, not Double


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/23f966f4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/23f966f4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/23f966f4

Branch: refs/heads/master
Commit: 23f966f47523f85ba440b4080eee665271f53b5e
Parents: 56f2c61
Author: Matei Zaharia <ma...@databricks.com>
Authored: Sat Nov 1 19:29:14 2014 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Sat Nov 1 19:29:14 2014 -0700

----------------------------------------------------------------------
 python/pyspark/sql.py                           |  35 +-
 .../spark/sql/catalyst/ScalaReflection.scala    |  20 +-
 .../apache/spark/sql/catalyst/SqlParser.scala   |  14 +-
 .../catalyst/analysis/HiveTypeCoercion.scala    | 146 +++++++-
 .../apache/spark/sql/catalyst/dsl/package.scala |  11 +-
 .../spark/sql/catalyst/expressions/Cast.scala   |  78 +++--
 .../sql/catalyst/expressions/aggregates.scala   |  55 ++-
 .../sql/catalyst/expressions/arithmetic.scala   |  10 +-
 .../expressions/codegen/CodeGenerator.scala     |  31 +-
 .../catalyst/expressions/decimalFunctions.scala |  59 ++++
 .../sql/catalyst/expressions/literals.scala     |   6 +-
 .../sql/catalyst/optimizer/Optimizer.scala      |  38 ++-
 .../spark/sql/catalyst/types/dataTypes.scala    |  84 ++++-
 .../sql/catalyst/types/decimal/Decimal.scala    | 335 +++++++++++++++++++
 .../sql/catalyst/ScalaReflectionSuite.scala     |  14 +-
 .../sql/catalyst/analysis/AnalysisSuite.scala   |   6 +-
 .../analysis/DecimalPrecisionSuite.scala        |  88 +++++
 .../analysis/HiveTypeCoercionSuite.scala        |  17 +-
 .../expressions/ExpressionEvaluationSuite.scala |  90 ++++-
 .../catalyst/types/decimal/DecimalSuite.scala   | 158 +++++++++
 .../org/apache/spark/sql/api/java/DataType.java |   5 -
 .../apache/spark/sql/api/java/DecimalType.java  |  58 +++-
 .../scala/org/apache/spark/sql/SchemaRDD.scala  |   3 +-
 .../spark/sql/api/java/JavaSQLContext.scala     |   2 +-
 .../org/apache/spark/sql/api/java/Row.scala     |   4 +
 .../sql/execution/GeneratedAggregate.scala      |  41 ++-
 .../apache/spark/sql/execution/SparkPlan.scala  |   4 +-
 .../sql/execution/SparkSqlSerializer.scala      |   2 +
 .../spark/sql/execution/basicOperators.scala    |   7 +-
 .../sql/execution/joins/BroadcastHashJoin.scala |   3 +-
 .../apache/spark/sql/execution/pythonUdfs.scala |   6 +-
 .../org/apache/spark/sql/json/JsonRDD.scala     |  20 +-
 .../scala/org/apache/spark/sql/package.scala    |  14 +
 .../spark/sql/parquet/ParquetConverter.scala    |  43 +++
 .../spark/sql/parquet/ParquetTableSupport.scala |  28 ++
 .../apache/spark/sql/parquet/ParquetTypes.scala |  79 +++--
 .../sql/types/util/DataTypeConversions.scala    |  13 +-
 .../sql/api/java/JavaApplySchemaSuite.java      |   2 +-
 .../java/JavaSideDataTypeConversionSuite.java   |   9 +-
 .../org/apache/spark/sql/DataTypeSuite.scala    |   2 +-
 .../sql/ScalaReflectionRelationSuite.scala      |   5 +-
 .../spark/sql/api/java/JavaSQLSuite.scala       |   2 +
 .../java/ScalaSideDataTypeConversionSuite.scala |   4 +-
 .../org/apache/spark/sql/json/JsonSuite.scala   |  46 +--
 .../spark/sql/parquet/ParquetQuerySuite.scala   |  35 +-
 .../server/SparkSQLOperationManager.scala       |   4 +-
 .../spark/sql/hive/thriftserver/Shim12.scala    |   4 +-
 .../spark/sql/hive/thriftserver/Shim13.scala    |   2 +-
 .../org/apache/spark/sql/hive/HiveContext.scala |   9 +-
 .../apache/spark/sql/hive/HiveInspectors.scala  |  24 +-
 .../spark/sql/hive/HiveMetastoreCatalog.scala   |  14 +-
 .../org/apache/spark/sql/hive/HiveQl.scala      |  15 +-
 .../hive/execution/InsertIntoHiveTable.scala    |   3 +-
 .../org/apache/spark/sql/hive/Shim12.scala      |  22 +-
 .../org/apache/spark/sql/hive/Shim13.scala      |  39 ++-
 55 files changed, 1636 insertions(+), 232 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 93bfc25..98e41f8 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -35,6 +35,7 @@ import datetime
 import keyword
 import warnings
 import json
+import re
 from array import array
 from operator import itemgetter
 from itertools import imap
@@ -148,13 +149,30 @@ class TimestampType(PrimitiveType):
     """
 
 
-class DecimalType(PrimitiveType):
+class DecimalType(DataType):
 
     """Spark SQL DecimalType
 
     The data type representing decimal.Decimal values.
     """
 
+    def __init__(self, precision=None, scale=None):
+        self.precision = precision
+        self.scale = scale
+        self.hasPrecisionInfo = precision is not None
+
+    def jsonValue(self):
+        if self.hasPrecisionInfo:
+            return "decimal(%d,%d)" % (self.precision, self.scale)
+        else:
+            return "decimal"
+
+    def __repr__(self):
+        if self.hasPrecisionInfo:
+            return "DecimalType(%d,%d)" % (self.precision, self.scale)
+        else:
+            return "DecimalType()"
+
 
 class DoubleType(PrimitiveType):
 
@@ -446,9 +464,20 @@ def _parse_datatype_json_string(json_string):
     return _parse_datatype_json_value(json.loads(json_string))
 
 
+_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")
+
+
 def _parse_datatype_json_value(json_value):
-    if type(json_value) is unicode and json_value in _all_primitive_types.keys():
-        return _all_primitive_types[json_value]()
+    if type(json_value) is unicode:
+        if json_value in _all_primitive_types.keys():
+            return _all_primitive_types[json_value]()
+        elif json_value == u'decimal':
+            return DecimalType()
+        elif _FIXED_DECIMAL.match(json_value):
+            m = _FIXED_DECIMAL.match(json_value)
+            return DecimalType(int(m.group(1)), int(m.group(2)))
+        else:
+            raise ValueError("Could not parse datatype: %s" % json_value)
     else:
         return _all_complex_types[json_value["type"]].fromJson(json_value)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 75923d9..8fbdf66 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst
 
 import java.sql.{Date, Timestamp}
 
-import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 
 /**
  * Provides experimental support for generating catalyst schemas for scala objects.
@@ -40,9 +41,20 @@ object ScalaReflection {
     case s: Seq[_] => s.map(convertToCatalyst)
     case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
     case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
+    case d: BigDecimal => Decimal(d)
     case other => other
   }
 
+  /** Converts Catalyst types used internally in rows to standard Scala types */
+  def convertToScala(a: Any): Any = a match {
+    case s: Seq[_] => s.map(convertToScala)
+    case m: Map[_, _] => m.map { case (k, v) => convertToScala(k) -> convertToScala(v) }
+    case d: Decimal => d.toBigDecimal
+    case other => other
+  }
+
+  def convertRowToScala(r: Row): Row = new GenericRow(r.toArray.map(convertToScala))
+
   /** Returns a Sequence of attributes for the given case class type. */
   def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
     case Schema(s: StructType, _) =>
@@ -83,7 +95,8 @@ object ScalaReflection {
     case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
     case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
     case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
-    case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
+    case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
+    case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
     case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
     case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
     case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
@@ -111,8 +124,9 @@ object ScalaReflection {
     case obj: LongType.JvmType => LongType
     case obj: FloatType.JvmType => FloatType
     case obj: DoubleType.JvmType => DoubleType
-    case obj: DecimalType.JvmType => DecimalType
     case obj: DateType.JvmType => DateType
+    case obj: BigDecimal => DecimalType.Unlimited
+    case obj: Decimal => DecimalType.Unlimited
     case obj: TimestampType.JvmType => TimestampType
     case null => NullType
     // For other cases, there is no obvious mapping from the type of the given object to a
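
As a small illustration of the conversions added above (values are illustrative), BigDecimal fields are turned into the internal Decimal representation on the way into a row and back into BigDecimal on the way out:

    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.catalyst.types.decimal.Decimal

    val internal = ScalaReflection.convertToCatalyst(BigDecimal("1.5"))
    assert(internal.isInstanceOf[Decimal])

    val external = ScalaReflection.convertToScala(internal)
    assert(external == BigDecimal("1.5"))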

http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index b1e7570..00fc4d7 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -52,11 +52,13 @@ class SqlParser extends AbstractSparkSQLParser {
   protected val CASE = Keyword("CASE")
   protected val CAST = Keyword("CAST")
   protected val COUNT = Keyword("COUNT")
+  protected val DECIMAL = Keyword("DECIMAL")
   protected val DESC = Keyword("DESC")
   protected val DISTINCT = Keyword("DISTINCT")
   protected val ELSE = Keyword("ELSE")
   protected val END = Keyword("END")
   protected val EXCEPT = Keyword("EXCEPT")
+  protected val DOUBLE = Keyword("DOUBLE")
   protected val FALSE = Keyword("FALSE")
   protected val FIRST = Keyword("FIRST")
   protected val FROM = Keyword("FROM")
@@ -385,5 +387,15 @@ class SqlParser extends AbstractSparkSQLParser {
     }
 
   protected lazy val dataType: Parser[DataType] =
-    STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType
+    ( STRING ^^^ StringType
+    | TIMESTAMP ^^^ TimestampType
+    | DOUBLE ^^^ DoubleType
+    | fixedDecimalType
+    | DECIMAL ^^^ DecimalType.Unlimited
+    )
+
+  protected lazy val fixedDecimalType: Parser[DataType] =
+    (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
+      case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
+    }
 }
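
With the grammar additions above, DOUBLE, DECIMAL and DECIMAL(p, s) become valid type names in SQL text. A hedged sketch, assuming the dataType production is used for CAST targets as elsewhere in SqlParser (table and column names are hypothetical):

    // Queries of this shape are now parseable:
    val examples = Seq(
      "SELECT CAST(amount AS DECIMAL(10,2)) FROM orders",  // fixed precision and scale
      "SELECT CAST(amount AS DECIMAL) FROM orders",        // unlimited precision
      "SELECT CAST(amount AS DOUBLE) FROM orders")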

http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 2b69c02..e38114a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -25,19 +25,31 @@ import org.apache.spark.sql.catalyst.types._
 object HiveTypeCoercion {
   // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
   // The conversion for integral and floating point types have a linear widening hierarchy:
-  val numericPrecedence =
-    Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
-  val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil
+  private val numericPrecedence =
+    Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited)
 
+  /**
+   * Find the tightest common type of two types that might be used in a binary expression.
+   * This handles all numeric types except fixed-precision decimals interacting with each other or
+   * with primitive types, because in that case the precision and scale of the result depends on
+   * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]].
+   */
   def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
     val valueTypes = Seq(t1, t2).filter(t => t != NullType)
     if (valueTypes.distinct.size > 1) {
-      // Try and find a promotion rule that contains both types in question.
-      val applicableConversion =
-        HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
-
-      // If found return the widest common type, otherwise None
-      applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
+      // Promote numeric types to the highest of the two and all numeric types to unlimited decimal
+      if (numericPrecedence.contains(t1) && numericPrecedence.contains(t2)) {
+        Some(numericPrecedence.filter(t => t == t1 || t == t2).last)
+      } else if (t1.isInstanceOf[DecimalType] && t2.isInstanceOf[DecimalType]) {
+        // Fixed-precision decimals can up-cast into unlimited
+        if (t1 == DecimalType.Unlimited || t2 == DecimalType.Unlimited) {
+          Some(DecimalType.Unlimited)
+        } else {
+          None
+        }
+      } else {
+        None
+      }
     } else {
       Some(if (valueTypes.size == 0) NullType else valueTypes.head)
     }
@@ -59,6 +71,7 @@ trait HiveTypeCoercion {
     ConvertNaNs ::
     WidenTypes ::
     PromoteStrings ::
+    DecimalPrecision ::
     BooleanComparisons ::
     BooleanCasts ::
     StringToIntegralCasts ::
@@ -151,6 +164,7 @@ trait HiveTypeCoercion {
     import HiveTypeCoercion._
 
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+      // TODO: unions with fixed-precision decimals
       case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
         val castedInput = left.output.zip(right.output).map {
           // When a string is found on one side, make the other side a string too.
@@ -265,6 +279,110 @@ trait HiveTypeCoercion {
     }
   }
 
+  // scalastyle:off
+  /**
+   * Calculates and propagates precision for fixed-precision decimals. Hive has a number of
+   * rules for this based on the SQL standard and MS SQL:
+   * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
+   *
+   * In particular, if we have expressions e1 and e2 with precision/scale p1/s1 and p2/s2
+   * respectively, then the following operations have the following precision / scale:
+   *
+   *   Operation    Result Precision                        Result Scale
+   *   ------------------------------------------------------------------------
+   *   e1 + e2      max(s1, s2) + max(p1-s1, p2-s2) + 1     max(s1, s2)
+   *   e1 - e2      max(s1, s2) + max(p1-s1, p2-s2) + 1     max(s1, s2)
+   *   e1 * e2      p1 + p2 + 1                             s1 + s2
+   *   e1 / e2      p1 - s1 + s2 + max(6, s1 + p2 + 1)      max(6, s1 + p2 + 1)
+   *   e1 % e2      min(p1-s1, p2-s2) + max(s1, s2)         max(s1, s2)
+   *   sum(e1)      p1 + 10                                 s1
+   *   avg(e1)      p1 + 4                                  s1 + 4
+   *
+   * Catalyst also has unlimited-precision decimals. For those, all ops return unlimited precision.
+   *
+   * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited
+   * precision, do the math on unlimited-precision numbers, then introduce casts back to the
+   * required fixed precision. This allows us to do all rounding and overflow handling in the
+   * cast-to-fixed-precision operator.
+   *
+   * In addition, when mixing non-decimal types with decimals, we use the following rules:
+   * - BYTE gets turned into DECIMAL(3, 0)
+   * - SHORT gets turned into DECIMAL(5, 0)
+   * - INT gets turned into DECIMAL(10, 0)
+   * - LONG gets turned into DECIMAL(20, 0)
+   * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive,
+   *   but note that unlimited decimals are considered bigger than doubles in WidenTypes)
+   */
+  // scalastyle:on
+  object DecimalPrecision extends Rule[LogicalPlan] {
+    import scala.math.{max, min}
+
+    // Conversion rules for integer types into fixed-precision decimals
+    val intTypeToFixed: Map[DataType, DecimalType] = Map(
+      ByteType -> DecimalType(3, 0),
+      ShortType -> DecimalType(5, 0),
+      IntegerType -> DecimalType(10, 0),
+      LongType -> DecimalType(20, 0)
+    )
+
+    def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
+
+    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+      // Skip nodes whose children have not been resolved yet
+      case e if !e.childrenResolved => e
+
+      case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+        Cast(
+          Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+          DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
+        )
+
+      case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+        Cast(
+          Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+          DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
+        )
+
+      case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+        Cast(
+          Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+          DecimalType(p1 + p2 + 1, s1 + s2)
+        )
+
+      case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+        Cast(
+          Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+          DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
+        )
+
+      case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+        Cast(
+          Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+          DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+        )
+
+      // Promote integers inside a binary expression with fixed-precision decimals to decimals,
+      // and fixed-precision decimals in an expression with floats / doubles to doubles
+      case b: BinaryExpression if b.left.dataType != b.right.dataType =>
+        (b.left.dataType, b.right.dataType) match {
+          case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
+            b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right))
+          case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
+            b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t))))
+          case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
+            b.makeCopy(Array(b.left, Cast(b.right, DoubleType)))
+          case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
+            b.makeCopy(Array(Cast(b.left, DoubleType), b.right))
+          case _ =>
+            b
+        }
+
+      // TODO: MaxOf, MinOf, etc might want other rules
+
+      // SUM and AVERAGE are handled by the implementations of those expressions
+    }
+  }
+
   /**
    * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated.
    */
@@ -330,7 +448,7 @@ trait HiveTypeCoercion {
       case e if !e.childrenResolved => e
 
       case Cast(e @ StringType(), t: IntegralType) =>
-        Cast(Cast(e, DecimalType), t)
+        Cast(Cast(e, DecimalType.Unlimited), t)
     }
   }
 
@@ -383,10 +501,12 @@ trait HiveTypeCoercion {
 
       // Decimal and Double remain the same
       case d: Divide if d.resolved && d.dataType == DoubleType => d
-      case d: Divide if d.resolved && d.dataType == DecimalType => d
+      case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d
 
-      case Divide(l, r) if l.dataType == DecimalType => Divide(l, Cast(r, DecimalType))
-      case Divide(l, r) if r.dataType == DecimalType => Divide(Cast(l, DecimalType), r)
+      case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] =>
+        Divide(l, Cast(r, DecimalType.Unlimited))
+      case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] =>
+        Divide(Cast(l, DecimalType.Unlimited), r)
 
       case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
     }
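
A worked example of the addition rule in the table above (not part of the patch): adding DECIMAL(10,2) and DECIMAL(5,3) gives scale max(2,3) = 3 and precision max(2,3) + max(10-2, 5-3) + 1 = 12, so the result is cast back to DECIMAL(12,3). The same arithmetic as a self-contained snippet:

    import scala.math.max

    // Result (precision, scale) of e1 + e2 for DECIMAL(p1,s1) and DECIMAL(p2,s2)
    def addResultType(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) =
      (max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))

    assert(addResultType(10, 2, 5, 3) == (12, 3))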

http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 23cfd48..7e6d770 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
 
 import java.sql.{Date, Timestamp}
 
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
 import scala.language.implicitConversions
 
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
@@ -124,7 +126,8 @@ package object dsl {
     implicit def doubleToLiteral(d: Double) = Literal(d)
     implicit def stringToLiteral(s: String) = Literal(s)
     implicit def dateToLiteral(d: Date) = Literal(d)
-    implicit def decimalToLiteral(d: BigDecimal) = Literal(d)
+    implicit def bigDecimalToLiteral(d: BigDecimal) = Literal(d)
+    implicit def decimalToLiteral(d: Decimal) = Literal(d)
     implicit def timestampToLiteral(t: Timestamp) = Literal(t)
     implicit def binaryToLiteral(a: Array[Byte]) = Literal(a)
 
@@ -183,7 +186,11 @@ package object dsl {
       def date = AttributeReference(s, DateType, nullable = true)()
 
       /** Creates a new AttributeReference of type decimal */
-      def decimal = AttributeReference(s, DecimalType, nullable = true)()
+      def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)()
+
+      /** Creates a new AttributeReference of type decimal */
+      def decimal(precision: Int, scale: Int) =
+        AttributeReference(s, DecimalType(precision, scale), nullable = true)()
 
       /** Creates a new AttributeReference of type timestamp */
       def timestamp = AttributeReference(s, TimestampType, nullable = true)()
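
A small sketch of the DSL addition above (the attribute name is illustrative):

    import org.apache.spark.sql.catalyst.dsl.expressions._

    val price      = 'price.decimal          // AttributeReference of type DecimalType.Unlimited
    val fixedPrice = 'price.decimal(10, 2)   // AttributeReference of type DecimalType(10, 2)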

http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8e5baf0..2200966 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -23,6 +23,7 @@ import java.text.{DateFormat, SimpleDateFormat}
 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 
 /** Cast the child expression to the target data type. */
 case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
@@ -36,6 +37,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
     case (BooleanType, DateType)      => true
     case (DateType, _: NumericType)   => true
     case (DateType, BooleanType)      => true
+    case (_, DecimalType.Fixed(_, _)) => true  // TODO: not all upcasts here can really give null
     case _                            => child.nullable
   }
 
@@ -76,8 +78,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Short](_, _ != 0)
     case ByteType =>
       buildCast[Byte](_, _ != 0)
-    case DecimalType =>
-      buildCast[BigDecimal](_, _ != 0)
+    case DecimalType() =>
+      buildCast[Decimal](_, _ != 0)
     case DoubleType =>
       buildCast[Double](_, _ != 0)
     case FloatType =>
@@ -109,19 +111,19 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
     case DateType =>
       buildCast[Date](_, d => new Timestamp(d.getTime))
     // TimestampWritable.decimalToTimestamp
-    case DecimalType =>
-      buildCast[BigDecimal](_, d => decimalToTimestamp(d))
+    case DecimalType() =>
+      buildCast[Decimal](_, d => decimalToTimestamp(d))
     // TimestampWritable.doubleToTimestamp
     case DoubleType =>
-      buildCast[Double](_, d => decimalToTimestamp(d))
+      buildCast[Double](_, d => decimalToTimestamp(Decimal(d)))
     // TimestampWritable.floatToTimestamp
     case FloatType =>
-      buildCast[Float](_, f => decimalToTimestamp(f))
+      buildCast[Float](_, f => decimalToTimestamp(Decimal(f)))
   }
 
-  private[this]  def decimalToTimestamp(d: BigDecimal) = {
+  private[this]  def decimalToTimestamp(d: Decimal) = {
     val seconds = Math.floor(d.toDouble).toLong
-    val bd = (d - seconds) * 1000000000
+    val bd = (d.toBigDecimal - seconds) * 1000000000
     val nanos = bd.intValue()
 
     val millis = seconds * 1000
@@ -196,8 +198,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToLong(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToLong(t))
-    case DecimalType =>
-      buildCast[BigDecimal](_, _.toLong)
+    case DecimalType() =>
+      buildCast[Decimal](_, _.toLong)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
   }
@@ -214,8 +216,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToLong(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToLong(t).toInt)
-    case DecimalType =>
-      buildCast[BigDecimal](_, _.toInt)
+    case DecimalType() =>
+      buildCast[Decimal](_, _.toInt)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
   }
@@ -232,8 +234,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToLong(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToLong(t).toShort)
-    case DecimalType =>
-      buildCast[BigDecimal](_, _.toShort)
+    case DecimalType() =>
+      buildCast[Decimal](_, _.toShort)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
   }
@@ -250,27 +252,45 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToLong(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToLong(t).toByte)
-    case DecimalType =>
-      buildCast[BigDecimal](_, _.toByte)
+    case DecimalType() =>
+      buildCast[Decimal](_, _.toByte)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
   }
 
-  // DecimalConverter
-  private[this] def castToDecimal: Any => Any = child.dataType match {
+  /**
+   * Change the precision / scale in a given decimal to those set in `decimalType` (if any),
+   * returning null if it overflows or modifying `value` in-place and returning it if successful.
+   *
+   * NOTE: this modifies `value` in-place, so don't call it on external data.
+   */
+  private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = {
+    decimalType match {
+      case DecimalType.Unlimited =>
+        value
+      case DecimalType.Fixed(precision, scale) =>
+        if (value.changePrecision(precision, scale)) value else null
+    }
+  }
+
+  private[this] def castToDecimal(target: DecimalType): Any => Any = child.dataType match {
     case StringType =>
-      buildCast[String](_, s => try BigDecimal(s.toDouble) catch {
+      buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch {
         case _: NumberFormatException => null
       })
     case BooleanType =>
-      buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0))
+      buildCast[Boolean](_, b => changePrecision(if (b) Decimal(1) else Decimal(0), target))
     case DateType =>
-      buildCast[Date](_, d => dateToDouble(d))
+      buildCast[Date](_, d => changePrecision(null, target)) // date can't cast to decimal in Hive
     case TimestampType =>
       // Note that we lose precision here.
-      buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
-    case x: NumericType =>
-      b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
+      buildCast[Timestamp](_, t => changePrecision(Decimal(timestampToDouble(t)), target))
+    case DecimalType() =>
+      b => changePrecision(b.asInstanceOf[Decimal].clone(), target)
+    case LongType =>
+      b => changePrecision(Decimal(b.asInstanceOf[Long]), target)
+    case x: NumericType =>  // All other numeric types can be represented precisely as Doubles
+      b => changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target)
   }
 
   // DoubleConverter
@@ -285,8 +305,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToDouble(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToDouble(t))
-    case DecimalType =>
-      buildCast[BigDecimal](_, _.toDouble)
+    case DecimalType() =>
+      buildCast[Decimal](_, _.toDouble)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
   }
@@ -303,8 +323,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToDouble(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToDouble(t).toFloat)
-    case DecimalType =>
-      buildCast[BigDecimal](_, _.toFloat)
+    case DecimalType() =>
+      buildCast[Decimal](_, _.toFloat)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
   }
@@ -313,8 +333,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
     case dt if dt == child.dataType => identity[Any]
     case StringType    => castToString
     case BinaryType    => castToBinary
-    case DecimalType   => castToDecimal
     case DateType      => castToDate
+    case decimal: DecimalType => castToDecimal(decimal)
     case TimestampType => castToTimestamp
     case BooleanType   => castToBoolean
     case ByteType      => castToByte
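
A hedged illustration of the overflow behaviour used by castToDecimal above, assuming changePrecision rounds to the target scale before checking the precision (values are illustrative):

    import org.apache.spark.sql.catalyst.types.decimal.Decimal

    val d = Decimal(BigDecimal("123.456"))
    assert(d.changePrecision(6, 3))    // fits in DECIMAL(6,3)
    assert(!d.changePrecision(4, 2))   // rounded value 123.46 needs 5 digits,
                                       // so a Cast to DECIMAL(4,2) evaluates to null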

http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 1b4d892..2b364fc 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -286,18 +286,38 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
 case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
 
   override def nullable = false
-  override def dataType = DoubleType
+
+  override def dataType = child.dataType match {
+    case DecimalType.Fixed(precision, scale) =>
+      DecimalType(precision + 4, scale + 4)  // Add 4 digits after decimal point, like Hive
+    case DecimalType.Unlimited =>
+      DecimalType.Unlimited
+    case _ =>
+      DoubleType
+  }
+
   override def toString = s"AVG($child)"
 
   override def asPartial: SplitEvaluation = {
     val partialSum = Alias(Sum(child), "PartialSum")()
     val partialCount = Alias(Count(child), "PartialCount")()
-    val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
-    val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
 
-    SplitEvaluation(
-      Divide(castedSum, castedCount),
-      partialCount :: partialSum :: Nil)
+    child.dataType match {
+      case DecimalType.Fixed(_, _) =>
+        // Turn the results to unlimited decimals for the division, before going back to fixed
+        val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited)
+        val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited)
+        SplitEvaluation(
+          Cast(Divide(castedSum, castedCount), dataType),
+          partialCount :: partialSum :: Nil)
+
+      case _ =>
+        val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
+        val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
+        SplitEvaluation(
+          Divide(castedSum, castedCount),
+          partialCount :: partialSum :: Nil)
+    }
   }
 
   override def newInstance() = new AverageFunction(child, this)
@@ -306,7 +326,16 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
 case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
 
   override def nullable = false
-  override def dataType = child.dataType
+
+  override def dataType = child.dataType match {
+    case DecimalType.Fixed(precision, scale) =>
+      DecimalType(precision + 10, scale)  // Add 10 digits left of decimal point, like Hive
+    case DecimalType.Unlimited =>
+      DecimalType.Unlimited
+    case _ =>
+      child.dataType
+  }
+
   override def toString = s"SUM($child)"
 
   override def asPartial: SplitEvaluation = {
@@ -322,9 +351,17 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
 case class SumDistinct(child: Expression)
   extends AggregateExpression with trees.UnaryNode[Expression] {
 
-
   override def nullable = false
-  override def dataType = child.dataType
+
+  override def dataType = child.dataType match {
+    case DecimalType.Fixed(precision, scale) =>
+      DecimalType(precision + 10, scale)  // Add 10 digits left of decimal point, like Hive
+    case DecimalType.Unlimited =>
+      DecimalType.Unlimited
+    case _ =>
+      child.dataType
+  }
+
   override def toString = s"SUM(DISTINCT $child)"
 
   override def newInstance() = new SumDistinctFunction(child, this)
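
The result types implied by the rules above, written out for one concrete input type (illustrative only):

    import org.apache.spark.sql.catalyst.types.DecimalType

    // For an input column of type DECIMAL(10,2):
    val sumType = DecimalType(10 + 10, 2)     // SUM -> DECIMAL(20,2): 10 extra digits left of the point
    val avgType = DecimalType(10 + 4, 2 + 4)  // AVG -> DECIMAL(14,6): 4 extra digits after the point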

http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 83e8466..8574cab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -36,7 +36,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression {
 
 case class Sqrt(child: Expression) extends UnaryExpression {
   type EvaluatedType = Any
-  
+
   def dataType = DoubleType
   override def foldable = child.foldable
   def nullable = child.nullable
@@ -55,7 +55,9 @@ abstract class BinaryArithmetic extends BinaryExpression {
   def nullable = left.nullable || right.nullable
 
   override lazy val resolved =
-    left.resolved && right.resolved && left.dataType == right.dataType
+    left.resolved && right.resolved &&
+    left.dataType == right.dataType &&
+    !DecimalType.isFixed(left.dataType)
 
   def dataType = {
     if (!resolved) {
@@ -104,6 +106,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
 case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
   def symbol = "/"
 
+  override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType]
+
   override def eval(input: Row): Any = dataType match {
     case _: FractionalType => f2(input, left, right, _.div(_, _))
     case _: IntegralType => i2(input, left , right, _.quot(_, _))
@@ -114,6 +118,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
 case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
   def symbol = "%"
 
+  override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType]
+
   override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _))
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 5a3f013..67f8d41 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.codegen
 
 import com.google.common.cache.{CacheLoader, CacheBuilder}
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 
 import scala.language.existentials
 
@@ -485,6 +486,34 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
           }
         """.children
 
+      case UnscaledValue(child) =>
+        val childEval = expressionEvaluator(child)
+
+        childEval.code ++
+        q"""
+         var $nullTerm = ${childEval.nullTerm}
+         var $primitiveTerm: Long = if (!$nullTerm) {
+           ${childEval.primitiveTerm}.toUnscaledLong
+         } else {
+           ${defaultPrimitive(LongType)}
+         }
+         """.children
+
+      case MakeDecimal(child, precision, scale) =>
+        val childEval = expressionEvaluator(child)
+
+        childEval.code ++
+        q"""
+         var $nullTerm = ${childEval.nullTerm}
+         var $primitiveTerm: org.apache.spark.sql.catalyst.types.decimal.Decimal =
+           ${defaultPrimitive(DecimalType())}
+
+         if (!$nullTerm) {
+           $primitiveTerm = new org.apache.spark.sql.catalyst.types.decimal.Decimal()
+           $primitiveTerm = $primitiveTerm.setOrNull(${childEval.primitiveTerm}, $precision, $scale)
+           $nullTerm = $primitiveTerm == null
+         }
+         """.children
     }
 
     // If there was no match in the partial function above, we fall back on calling the interpreted
@@ -562,7 +591,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
     case LongType => ru.Literal(Constant(1L))
     case ByteType => ru.Literal(Constant(-1.toByte))
     case DoubleType => ru.Literal(Constant(-1.toDouble))
-    case DecimalType => ru.Literal(Constant(-1)) // Will get implicity converted as needed.
+    case DecimalType() => q"org.apache.spark.sql.catalyst.types.decimal.Decimal(-1)"
     case IntegerType => ru.Literal(Constant(-1))
     case _ => ru.Literal(Constant(null))
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
new file mode 100644
index 0000000..d1eab2e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+import org.apache.spark.sql.catalyst.types.{DecimalType, LongType, DoubleType, DataType}
+
+/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */
+case class UnscaledValue(child: Expression) extends UnaryExpression {
+  override type EvaluatedType = Any
+
+  override def dataType: DataType = LongType
+  override def foldable = child.foldable
+  def nullable = child.nullable
+  override def toString = s"UnscaledValue($child)"
+
+  override def eval(input: Row): Any = {
+    val childResult = child.eval(input)
+    if (childResult == null) {
+      null
+    } else {
+      childResult.asInstanceOf[Decimal].toUnscaledLong
+    }
+  }
+}
+
+/** Create a Decimal from an unscaled Long value */
+case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {
+  override type EvaluatedType = Decimal
+
+  override def dataType: DataType = DecimalType(precision, scale)
+  override def foldable = child.foldable
+  def nullable = child.nullable
+  override def toString = s"MakeDecimal($child,$precision,$scale)"
+
+  override def eval(input: Row): Decimal = {
+    val childResult = child.eval(input)
+    if (childResult == null) {
+      null
+    } else {
+      new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale)
+    }
+  }
+}
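
A small round trip through the two new expressions (the literal values are illustrative):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types.decimal.Decimal

    val d = new Decimal().set(1234L, 4, 2)                       // 12.34 as DECIMAL(4,2)
    val unscaled = UnscaledValue(Literal(d)).eval(null)          // 1234L
    val rebuilt  = MakeDecimal(Literal(1234L), 4, 2).eval(null)  // Decimal 12.34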

http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index ba24023..93c1932 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import java.sql.{Date, Timestamp}
 
 import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 
 object Literal {
   def apply(v: Any): Literal = v match {
@@ -31,7 +32,8 @@ object Literal {
     case s: Short => Literal(s, ShortType)
     case s: String => Literal(s, StringType)
     case b: Boolean => Literal(b, BooleanType)
-    case d: BigDecimal => Literal(d, DecimalType)
+    case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
+    case d: Decimal => Literal(d, DecimalType.Unlimited)
     case t: Timestamp => Literal(t, TimestampType)
     case d: Date => Literal(d, DateType)
     case a: Array[Byte] => Literal(a, BinaryType)
@@ -62,7 +64,7 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression {
 }
 
 // TODO: Specialize
-case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) 
+case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true)
     extends LeafExpression {
   type EvaluatedType = Any
 

http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9ce7c78..a4aa322 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.LeftSemi
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 
 abstract class Optimizer extends RuleExecutor[LogicalPlan]
 
@@ -43,6 +44,8 @@ object DefaultOptimizer extends Optimizer {
       SimplifyCasts,
       SimplifyCaseConversionExpressions,
       OptimizeIn) ::
+    Batch("Decimal Optimizations", FixedPoint(100),
+      DecimalAggregates) ::
     Batch("Filter Pushdown", FixedPoint(100),
       UnionPushdown,
       CombineFilters,
@@ -390,9 +393,9 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
  * evaluated using only the attributes of the left or right side of a join.  Other
  * [[Filter]] conditions are moved into the `condition` of the [[Join]].
  *
- * And also Pushes down the join filter, where the `condition` can be evaluated using only the 
- * attributes of the left or right side of sub query when applicable. 
- * 
+ * And also Pushes down the join filter, where the `condition` can be evaluated using only the
+ * attributes of the left or right side of sub query when applicable.
+ *
  * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details
  */
 object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
@@ -404,7 +407,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
   private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
     val (leftEvaluateCondition, rest) =
         condition.partition(_.references subsetOf left.outputSet)
-    val (rightEvaluateCondition, commonCondition) = 
+    val (rightEvaluateCondition, commonCondition) =
         rest.partition(_.references subsetOf right.outputSet)
 
     (leftEvaluateCondition, rightEvaluateCondition, commonCondition)
@@ -413,7 +416,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     // push the where condition down into join filter
     case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) =>
-      val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = 
+      val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
         split(splitConjunctivePredicates(filterCondition), left, right)
 
       joinType match {
@@ -451,7 +454,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
 
     // push down the join filter into sub query scanning if applicable
     case f @ Join(left, right, joinType, joinCondition) =>
-      val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = 
+      val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
         split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
 
       joinType match {
@@ -519,3 +522,26 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
     }
   }
 }
+
+/**
+ * Speeds up aggregates on fixed-precision decimals by executing them on unscaled Long values.
+ *
+ * This uses the same rules for increasing the precision and scale of the output as
+ * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.DecimalPrecision]].
+ */
+object DecimalAggregates extends Rule[LogicalPlan] {
+  import Decimal.MAX_LONG_DIGITS
+
+  /** Maximum number of decimal digits representable precisely in a Double */
+  val MAX_DOUBLE_DIGITS = 15
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
+      MakeDecimal(Sum(UnscaledValue(e)), prec + 10, scale)
+
+    case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+      Cast(
+        Divide(Average(UnscaledValue(e)), Literal(math.pow(10.0, scale), DoubleType)),
+        DecimalType(prec + 4, scale + 4))
+  }
+}
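
A sketch of the rewrite performed by the Sum rule above, for a hypothetical column of type DECIMAL(8,2) (8 + 10 = 18 digits still fit in a Long's unscaled representation):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions._

    val price  = 'price.decimal(8, 2)
    val before = Sum(price)                                        // what the query contains
    val after  = MakeDecimal(Sum(UnscaledValue(price)), 8 + 10, 2) // what the rule produces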

http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index 6069f9b..8dda0b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.types
 
 import java.sql.{Date, Timestamp}
 
-import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral}
+import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral}
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag}
 import scala.util.parsing.combinator.RegexParsers
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
 import org.apache.spark.sql.catalyst.util.Metadata
 import org.apache.spark.util.Utils
+import org.apache.spark.sql.catalyst.types.decimal._
 
 object DataType {
   def fromJson(json: String): DataType = parseDataType(parse(json))
@@ -91,11 +92,17 @@ object DataType {
       | "LongType" ^^^ LongType
       | "BinaryType" ^^^ BinaryType
       | "BooleanType" ^^^ BooleanType
-      | "DecimalType" ^^^ DecimalType
       | "DateType" ^^^ DateType
+      | "DecimalType()" ^^^ DecimalType.Unlimited
+      | fixedDecimalType
       | "TimestampType" ^^^ TimestampType
       )
 
+    protected lazy val fixedDecimalType: Parser[DataType] =
+      ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ {
+        case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
+      }
+
     protected lazy val arrayType: Parser[DataType] =
       "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ {
         case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
@@ -200,10 +207,18 @@ trait PrimitiveType extends DataType {
 }
 
 object PrimitiveType {
-  private[sql] val all = Seq(DecimalType, DateType, TimestampType, BinaryType) ++
-    NativeType.all
-
-  private[sql] val nameToType = all.map(t => t.typeName -> t).toMap
+  private val nonDecimals = Seq(DateType, TimestampType, BinaryType) ++ NativeType.all
+  private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap
+
+  /** Given the string representation of a type, return its DataType */
+  private[sql] def nameToType(name: String): DataType = {
+    val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r
+    name match {
+      case "decimal" => DecimalType.Unlimited
+      case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
+      case other => nonDecimalNameToType(other)
+    }
+  }
 }
 
 abstract class NativeType extends DataType {
@@ -332,13 +347,58 @@ abstract class FractionalType extends NumericType {
   private[sql] val asIntegral: Integral[JvmType]
 }
 
-case object DecimalType extends FractionalType {
-  private[sql] type JvmType = BigDecimal
+/** Precision parameters for a Decimal */
+case class PrecisionInfo(precision: Int, scale: Int)
+
+/** A Decimal that might have fixed precision and scale, or unlimited values for these */
+case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType {
+  private[sql] type JvmType = Decimal
   @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
-  private[sql] val numeric = implicitly[Numeric[BigDecimal]]
-  private[sql] val fractional = implicitly[Fractional[BigDecimal]]
-  private[sql] val ordering = implicitly[Ordering[JvmType]]
-  private[sql] val asIntegral = BigDecimalAsIfIntegral
+  private[sql] val numeric = Decimal.DecimalIsFractional
+  private[sql] val fractional = Decimal.DecimalIsFractional
+  private[sql] val ordering = Decimal.DecimalIsFractional
+  private[sql] val asIntegral = Decimal.DecimalAsIfIntegral
+
+  override def typeName: String = precisionInfo match {
+    case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
+    case None => "decimal"
+  }
+
+  override def toString: String = precisionInfo match {
+    case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)"
+    case None => "DecimalType()"
+  }
+}
+
+/** Extra factory methods and pattern matchers for Decimals */
+object DecimalType {
+  val Unlimited: DecimalType = DecimalType(None)
+
+  object Fixed {
+    def unapply(t: DecimalType): Option[(Int, Int)] =
+      t.precisionInfo.map(p => (p.precision, p.scale))
+  }
+
+  object Expression {
+    def unapply(e: Expression): Option[(Int, Int)] = e.dataType match {
+      case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale))
+      case _ => None
+    }
+  }
+
+  def apply(): DecimalType = Unlimited
+
+  def apply(precision: Int, scale: Int): DecimalType =
+    DecimalType(Some(PrecisionInfo(precision, scale)))
+
+  def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType]
+
+  def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType]
+
+  def isFixed(dataType: DataType): Boolean = dataType match {
+    case DecimalType.Fixed(_, _) => true
+    case _ => false
+  }
 }
 
 case object DoubleType extends FractionalType {

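DecimalType is now a case class rather than a case object, so it can carry an optional PrecisionInfo, while DecimalType.Unlimited plays the role of the old parameterless type. A short sketch of how the factories and the Fixed extractor above are meant to be used (illustrative only, not part of the patch):

    import org.apache.spark.sql.catalyst.types._

    val unlimited = DecimalType()      // same instance as DecimalType.Unlimited; typeName "decimal"
    val fixed     = DecimalType(10, 2) // typeName "decimal(10,2)"

    def describe(dt: DataType): String = dt match {
      case DecimalType.Fixed(p, s) => s"fixed decimal($p,$s)"
      case _: DecimalType          => "unlimited decimal"   // only Unlimited is left at this point
      case _                       => "not a decimal"
    }

    describe(fixed)       // "fixed decimal(10,2)"
    describe(unlimited)   // "unlimited decimal"
    describe(IntegerType) // "not a decimal"
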
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala
new file mode 100644
index 0000000..708362a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.types.decimal
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * A mutable implementation of BigDecimal that can hold a Long if values are small enough.
+ *
+ * The semantics of the fields are as follows:
+ * - _precision and _scale represent the SQL precision and scale we are looking for
+ * - If decimalVal is set, it represents the whole decimal value
+ * - Otherwise, the decimal value is longVal / (10 ** _scale)
+ */
+final class Decimal extends Ordered[Decimal] with Serializable {
+  import Decimal.{MAX_LONG_DIGITS, POW_10, ROUNDING_MODE, BIG_DEC_ZERO}
+
+  private var decimalVal: BigDecimal = null
+  private var longVal: Long = 0L
+  private var _precision: Int = 1
+  private var _scale: Int = 0
+
+  def precision: Int = _precision
+  def scale: Int = _scale
+
+  /**
+   * Set this Decimal to the given Long. Will have precision 20 and scale 0.
+   */
+  def set(longVal: Long): Decimal = {
+    if (longVal <= -POW_10(MAX_LONG_DIGITS) || longVal >= POW_10(MAX_LONG_DIGITS)) {
+      // We can't represent this compactly as a long without risking overflow
+      this.decimalVal = BigDecimal(longVal)
+      this.longVal = 0L
+    } else {
+      this.decimalVal = null
+      this.longVal = longVal
+    }
+    this._precision = 20
+    this._scale = 0
+    this
+  }
+
+  /**
+   * Set this Decimal to the given Int. Will have precision 10 and scale 0.
+   */
+  def set(intVal: Int): Decimal = {
+    this.decimalVal = null
+    this.longVal = intVal
+    this._precision = 10
+    this._scale = 0
+    this
+  }
+
+  /**
+   * Set this Decimal to the given unscaled Long, with a given precision and scale.
+   */
+  def set(unscaled: Long, precision: Int, scale: Int): Decimal = {
+    if (setOrNull(unscaled, precision, scale) == null) {
+      throw new IllegalArgumentException("Unscaled value too large for precision")
+    }
+    this
+  }
+
+  /**
+   * Set this Decimal to the given unscaled Long, with a given precision and scale,
+   * and return it, or return null if it cannot be set due to overflow.
+   */
+  def setOrNull(unscaled: Long, precision: Int, scale: Int): Decimal = {
+    if (unscaled <= -POW_10(MAX_LONG_DIGITS) || unscaled >= POW_10(MAX_LONG_DIGITS)) {
+      // We can't represent this compactly as a long without risking overflow
+      if (precision < 19) {
+        return null  // Requested precision is too low to represent this value
+      }
+      this.decimalVal = BigDecimal(unscaled, scale)
+      this.longVal = 0L
+    } else {
+      val p = POW_10(math.min(precision, MAX_LONG_DIGITS))
+      if (unscaled <= -p || unscaled >= p) {
+        return null  // Requested precision is too low to represent this value
+      }
+      this.decimalVal = null
+      this.longVal = unscaled
+    }
+    this._precision = precision
+    this._scale = scale
+    this
+  }
+
+  /**
+   * Set this Decimal to the given BigDecimal value, with a given precision and scale.
+   */
+  def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
+    this.decimalVal = decimal.setScale(scale, ROUNDING_MODE)
+    require(decimalVal.precision <= precision, "Overflowed precision")
+    this.longVal = 0L
+    this._precision = precision
+    this._scale = scale
+    this
+  }
+
+  /**
+   * Set this Decimal to the given BigDecimal value, inheriting its precision and scale.
+   */
+  def set(decimal: BigDecimal): Decimal = {
+    this.decimalVal = decimal
+    this.longVal = 0L
+    this._precision = decimal.precision
+    this._scale = decimal.scale
+    this
+  }
+
+  /**
+   * Set this Decimal to the given Decimal value.
+   */
+  def set(decimal: Decimal): Decimal = {
+    this.decimalVal = decimal.decimalVal
+    this.longVal = decimal.longVal
+    this._precision = decimal._precision
+    this._scale = decimal._scale
+    this
+  }
+
+  def toBigDecimal: BigDecimal = {
+    if (decimalVal.ne(null)) {
+      decimalVal
+    } else {
+      BigDecimal(longVal, _scale)
+    }
+  }
+
+  def toUnscaledLong: Long = {
+    if (decimalVal.ne(null)) {
+      decimalVal.underlying().unscaledValue().longValue()
+    } else {
+      longVal
+    }
+  }
+
+  override def toString: String = toBigDecimal.toString()
+
+  @DeveloperApi
+  def toDebugString: String = {
+    if (decimalVal.ne(null)) {
+      s"Decimal(expanded,$decimalVal,$precision,$scale})"
+    } else {
+      s"Decimal(compact,$longVal,$precision,$scale})"
+    }
+  }
+
+  def toDouble: Double = toBigDecimal.doubleValue()
+
+  def toFloat: Float = toBigDecimal.floatValue()
+
+  def toLong: Long = {
+    if (decimalVal.eq(null)) {
+      longVal / POW_10(_scale)
+    } else {
+      decimalVal.longValue()
+    }
+  }
+
+  def toInt: Int = toLong.toInt
+
+  def toShort: Short = toLong.toShort
+
+  def toByte: Byte = toLong.toByte
+
+  /**
+   * Update precision and scale while keeping our value the same, and return true if successful.
+   *
+   * @return true if successful, false if overflow would occur
+   */
+  def changePrecision(precision: Int, scale: Int): Boolean = {
+    // First, update our longVal if we can, or transfer over to using a BigDecimal
+    if (decimalVal.eq(null)) {
+      if (scale < _scale) {
+        // Easier case: we just need to divide our scale down
+        val diff = _scale - scale
+        val droppedDigits = longVal % POW_10(diff)
+        longVal /= POW_10(diff)
+        if (math.abs(droppedDigits) * 2 >= POW_10(diff)) {
+          longVal += (if (longVal < 0) -1L else 1L)
+        }
+      } else if (scale > _scale) {
+        // We might be able to multiply longVal by a power of 10 and not overflow, but if not,
+        // switch to using a BigDecimal
+        val diff = scale - _scale
+        val p = POW_10(math.max(MAX_LONG_DIGITS - diff, 0))
+        if (diff <= MAX_LONG_DIGITS && longVal > -p && longVal < p) {
+          // Multiplying longVal by POW_10(diff) still keeps it within MAX_LONG_DIGITS digits
+          longVal *= POW_10(diff)
+        } else {
+          // Give up on using Longs; switch to BigDecimal, which we'll modify below
+          decimalVal = BigDecimal(longVal, _scale)
+        }
+      }
+      // In both cases, we will check whether our precision is okay below
+    }
+
+    if (decimalVal.ne(null)) {
+      // We get here if either we started with a BigDecimal, or we switched to one because we would
+      // have overflowed our Long; in either case we must rescale decimalVal to the new scale.
+      val newVal = decimalVal.setScale(scale, ROUNDING_MODE)
+      if (newVal.precision > precision) {
+        return false
+      }
+      decimalVal = newVal
+    } else {
+      // We're still using a Long, so check that the value fits within the new precision
+      val p = POW_10(math.min(precision, MAX_LONG_DIGITS))
+      if (longVal <= -p || longVal >= p) {
+        // Switching to BigDecimal would not help: the value has more digits than the new precision allows
+        return false
+      }
+    }
+
+    _precision = precision
+    _scale = scale
+    true
+  }
+
+  override def clone(): Decimal = new Decimal().set(this)
+
+  override def compare(other: Decimal): Int = {
+    if (decimalVal.eq(null) && other.decimalVal.eq(null) && _scale == other._scale) {
+      if (longVal < other.longVal) -1 else if (longVal == other.longVal) 0 else 1
+    } else {
+      toBigDecimal.compare(other.toBigDecimal)
+    }
+  }
+
+  override def equals(other: Any) = other match {
+    case d: Decimal =>
+      compare(d) == 0
+    case _ =>
+      false
+  }
+
+  override def hashCode(): Int = toBigDecimal.hashCode()
+
+  def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0
+
+  def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal)
+
+  def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal)
+
+  def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal)
+
+  def / (that: Decimal): Decimal =
+    if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal)
+
+  def % (that: Decimal): Decimal =
+    if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal)
+
+  def remainder(that: Decimal): Decimal = this % that
+
+  def unary_- : Decimal = {
+    if (decimalVal.ne(null)) {
+      Decimal(-decimalVal)
+    } else {
+      Decimal(-longVal, precision, scale)
+    }
+  }
+}
+
+object Decimal {
+  private val ROUNDING_MODE = BigDecimal.RoundingMode.HALF_UP
+
+  /** Maximum number of decimal digits that a Long can safely hold (every 18-digit value fits in a Long) */
+  val MAX_LONG_DIGITS = 18
+
+  private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong)
+
+  private val BIG_DEC_ZERO = BigDecimal(0)
+
+  def apply(value: Double): Decimal = new Decimal().set(value)
+
+  def apply(value: Long): Decimal = new Decimal().set(value)
+
+  def apply(value: Int): Decimal = new Decimal().set(value)
+
+  def apply(value: BigDecimal): Decimal = new Decimal().set(value)
+
+  def apply(value: BigDecimal, precision: Int, scale: Int): Decimal =
+    new Decimal().set(value, precision, scale)
+
+  def apply(unscaled: Long, precision: Int, scale: Int): Decimal =
+    new Decimal().set(unscaled, precision, scale)
+
+  def apply(value: String): Decimal = new Decimal().set(BigDecimal(value))
+
+  // Evidence parameters for Decimal considered either as Fractional or Integral. We provide two
+  // parameters inheriting from a common trait since both traits define mkNumericOps.
+  // See scala.math's Numeric.scala for examples for Scala's built-in types.
+
+  /** Common methods for Decimal evidence parameters */
+  trait DecimalIsConflicted extends Numeric[Decimal] {
+    override def plus(x: Decimal, y: Decimal): Decimal = x + y
+    override def times(x: Decimal, y: Decimal): Decimal = x * y
+    override def minus(x: Decimal, y: Decimal): Decimal = x - y
+    override def negate(x: Decimal): Decimal = -x
+    override def toDouble(x: Decimal): Double = x.toDouble
+    override def toFloat(x: Decimal): Float = x.toFloat
+    override def toInt(x: Decimal): Int = x.toInt
+    override def toLong(x: Decimal): Long = x.toLong
+    override def fromInt(x: Int): Decimal = new Decimal().set(x)
+    override def compare(x: Decimal, y: Decimal): Int = x.compare(y)
+  }
+
+  /** A [[scala.math.Fractional]] evidence parameter for Decimals. */
+  object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] {
+    override def div(x: Decimal, y: Decimal): Decimal = x / y
+  }
+
+  /** A [[scala.math.Integral]] evidence parameter for Decimals. */
+  object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] {
+    override def quot(x: Decimal, y: Decimal): Decimal = x / y
+    override def rem(x: Decimal, y: Decimal): Decimal = x % y
+  }
+}

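The new Decimal keeps values whose unscaled form fits in 18 digits directly in a Long and only falls back to a BigDecimal above that, while changePrecision rescales in place and signals overflow through its return value rather than an exception. A small sketch of the intended behavior (illustrative only, not part of the patch):

    import org.apache.spark.sql.catalyst.types.decimal.Decimal

    val small = Decimal(12345L, 5, 2)   // 123.45, unscaled value kept in the compact Long form
    val big   = Decimal(BigDecimal("12345678901234567890.12"), 22, 2)  // 22 digits, expanded form

    small.toDebugString   // "Decimal(compact,12345,5,2)"
    big.toDebugString     // "Decimal(expanded,12345678901234567890.12,22,2)"

    val d = Decimal(12345L, 5, 2)
    d.changePrecision(4, 1)                       // true: rounds HALF_UP to 123.5, still representable
    Decimal(12345L, 5, 2).changePrecision(2, 0)   // false: 123 needs three digits of precision

    Decimal(15L, 2, 1) + Decimal(25L, 2, 1)       // 4.0, computed through BigDecimal and re-wrapped
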
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 430f066..21b2c8e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -96,7 +96,7 @@ class ScalaReflectionSuite extends FunSuite {
         StructField("byteField", ByteType, nullable = true),
         StructField("booleanField", BooleanType, nullable = true),
         StructField("stringField", StringType, nullable = true),
-        StructField("decimalField", DecimalType, nullable = true),
+        StructField("decimalField", DecimalType.Unlimited, nullable = true),
         StructField("dateField", DateType, nullable = true),
         StructField("timestampField", TimestampType, nullable = true),
         StructField("binaryField", BinaryType, nullable = true))),
@@ -199,7 +199,7 @@ class ScalaReflectionSuite extends FunSuite {
     assert(DoubleType === typeOfObject(1.7976931348623157E308))
 
     // DecimalType
-    assert(DecimalType === typeOfObject(BigDecimal("1.7976931348623157E318")))
+    assert(DecimalType.Unlimited === typeOfObject(BigDecimal("1.7976931348623157E318")))
 
     // DateType
     assert(DateType === typeOfObject(Date.valueOf("2014-07-25")))
@@ -211,19 +211,19 @@ class ScalaReflectionSuite extends FunSuite {
     assert(NullType === typeOfObject(null))
 
     def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse {
-      case value: java.math.BigInteger => DecimalType
-      case value: java.math.BigDecimal => DecimalType
+      case value: java.math.BigInteger => DecimalType.Unlimited
+      case value: java.math.BigDecimal => DecimalType.Unlimited
       case _ => StringType
     }
 
-    assert(DecimalType === typeOfObject1(
+    assert(DecimalType.Unlimited === typeOfObject1(
       new BigInteger("92233720368547758070")))
-    assert(DecimalType === typeOfObject1(
+    assert(DecimalType.Unlimited === typeOfObject1(
       new java.math.BigDecimal("1.7976931348623157E318")))
     assert(StringType === typeOfObject1(BigInt("92233720368547758070")))
 
     def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse {
-      case value: java.math.BigInteger => DecimalType
+      case value: java.math.BigInteger => DecimalType.Unlimited
     }
 
     intercept[MatchError](typeOfObject2(BigInt("92233720368547758070")))

http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 7b45738..33a3cba 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -38,7 +38,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
     AttributeReference("a", StringType)(),
     AttributeReference("b", StringType)(),
     AttributeReference("c", DoubleType)(),
-    AttributeReference("d", DecimalType)(),
+    AttributeReference("d", DecimalType.Unlimited)(),
     AttributeReference("e", ShortType)())
 
   before {
@@ -119,7 +119,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
       AttributeReference("a", StringType)(),
       AttributeReference("b", StringType)(),
       AttributeReference("c", DoubleType)(),
-      AttributeReference("d", DecimalType)(),
+      AttributeReference("d", DecimalType.Unlimited)(),
       AttributeReference("e", ShortType)())
 
     val expr0 = 'a / 2
@@ -137,7 +137,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
     assert(pl(0).dataType == DoubleType)
     assert(pl(1).dataType == DoubleType)
     assert(pl(2).dataType == DoubleType)
-    assert(pl(3).dataType == DecimalType)
+    assert(pl(3).dataType == DecimalType.Unlimited)
     assert(pl(4).dataType == DoubleType)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
new file mode 100644
index 0000000..d5b7d27
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
+import org.apache.spark.sql.catalyst.types._
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
+  val catalog = new SimpleCatalog(false)
+  val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = false)
+
+  val relation = LocalRelation(
+    AttributeReference("i", IntegerType)(),
+    AttributeReference("d1", DecimalType(2, 1))(),
+    AttributeReference("d2", DecimalType(5, 2))(),
+    AttributeReference("u", DecimalType.Unlimited)(),
+    AttributeReference("f", FloatType)()
+  )
+
+  val i: Expression = UnresolvedAttribute("i")
+  val d1: Expression = UnresolvedAttribute("d1")
+  val d2: Expression = UnresolvedAttribute("d2")
+  val u: Expression = UnresolvedAttribute("u")
+  val f: Expression = UnresolvedAttribute("f")
+
+  before {
+    catalog.registerTable(None, "table", relation)
+  }
+
+  private def checkType(expression: Expression, expectedType: DataType): Unit = {
+    val plan = Project(Seq(Alias(expression, "c")()), relation)
+    assert(analyzer(plan).schema.fields(0).dataType === expectedType)
+  }
+
+  test("basic operations") {
+    checkType(Add(d1, d2), DecimalType(6, 2))
+    checkType(Subtract(d1, d2), DecimalType(6, 2))
+    checkType(Multiply(d1, d2), DecimalType(8, 3))
+    checkType(Divide(d1, d2), DecimalType(10, 7))
+    checkType(Divide(d2, d1), DecimalType(10, 6))
+    checkType(Remainder(d1, d2), DecimalType(3, 2))
+    checkType(Remainder(d2, d1), DecimalType(3, 2))
+    checkType(Sum(d1), DecimalType(12, 1))
+    checkType(Average(d1), DecimalType(6, 5))
+
+    checkType(Add(Add(d1, d2), d1), DecimalType(7, 2))
+    checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2))
+    checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2))
+  }
+
+  test("bringing in primitive types") {
+    checkType(Add(d1, i), DecimalType(12, 1))
+    checkType(Add(d1, f), DoubleType)
+    checkType(Add(i, d1), DecimalType(12, 1))
+    checkType(Add(f, d1), DoubleType)
+    checkType(Add(d1, Cast(i, LongType)), DecimalType(22, 1))
+    checkType(Add(d1, Cast(i, ShortType)), DecimalType(7, 1))
+    checkType(Add(d1, Cast(i, ByteType)), DecimalType(5, 1))
+    checkType(Add(d1, Cast(i, DoubleType)), DoubleType)
+  }
+
+  test("unlimited decimals make everything else cast up") {
+    for (expr <- Seq(d1, d2, i, f, u)) {
+      checkType(Add(expr, u), DecimalType.Unlimited)
+      checkType(Subtract(expr, u), DecimalType.Unlimited)
+      checkType(Multiply(expr, u), DecimalType.Unlimited)
+      checkType(Divide(expr, u), DecimalType.Unlimited)
+      checkType(Remainder(expr, u), DecimalType.Unlimited)
+    }
+  }
+}

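The expected result types in this suite follow Hive-style precision and scale propagation for fixed-precision operands. As a sketch, the addition and multiplication rules that the first test exercises can be written out as below (hypothetical helpers, stated only to make the expectations concrete; the actual coercion rule is implemented elsewhere in this patch):

    // For d1: DecimalType(2,1) and d2: DecimalType(5,2)
    def addResult(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = {
      val s = math.max(s1, s2)                 // keep the finer scale
      (math.max(p1 - s1, p2 - s2) + s + 1, s)  // (6, 2), matching checkType(Add(d1, d2), ...)
    }
    def multiplyResult(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) =
      (p1 + p2 + 1, s1 + s2)                   // (8, 3), matching checkType(Multiply(d1, d2), ...)
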
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index baeb9b0..dfa2d95 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -68,6 +68,21 @@ class HiveTypeCoercionSuite extends FunSuite {
     widenTest(LongType, FloatType, Some(FloatType))
     widenTest(LongType, DoubleType, Some(DoubleType))
 
+    // Casting up to unlimited-precision decimal
+    widenTest(IntegerType, DecimalType.Unlimited, Some(DecimalType.Unlimited))
+    widenTest(DoubleType, DecimalType.Unlimited, Some(DecimalType.Unlimited))
+    widenTest(DecimalType(3, 2), DecimalType.Unlimited, Some(DecimalType.Unlimited))
+    widenTest(DecimalType.Unlimited, IntegerType, Some(DecimalType.Unlimited))
+    widenTest(DecimalType.Unlimited, DoubleType, Some(DecimalType.Unlimited))
+    widenTest(DecimalType.Unlimited, DecimalType(3, 2), Some(DecimalType.Unlimited))
+
+    // No up-casting for fixed-precision decimal (this is handled by arithmetic rules)
+    widenTest(DecimalType(2, 1), DecimalType(3, 2), None)
+    widenTest(DecimalType(2, 1), DoubleType, None)
+    widenTest(DecimalType(2, 1), IntegerType, None)
+    widenTest(DoubleType, DecimalType(2, 1), None)
+    widenTest(IntegerType, DecimalType(2, 1), None)
+
     // StringType
     widenTest(NullType, StringType, Some(StringType))
     widenTest(StringType, StringType, Some(StringType))
@@ -92,7 +107,7 @@ class HiveTypeCoercionSuite extends FunSuite {
     def ruleTest(initial: Expression, transformed: Expression) {
       val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
       assert(booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) ==
-        Project(Seq(Alias(transformed, "a")()), testRelation))      
+        Project(Seq(Alias(transformed, "a")()), testRelation))
     }
     // Remove superfluous boolean -> boolean casts.
     ruleTest(Cast(Literal(true), BooleanType), Literal(true))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org