You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/10/15 05:56:13 UTC

spark git commit: [SPARK-11076] [SQL] Add decimal support for floor and ceil

Repository: spark
Updated Branches:
  refs/heads/master 4ace4f8a9 -> 9808052b5


[SPARK-11076] [SQL] Add decimal support for floor and ceil

Currently, none of the `UnaryMathExpression` subclasses support Decimal; I will create follow-ups to add that support. This is the first such PR, so it is a good place to review the approach I am taking.

Author: Cheng Hao <ha...@intel.com>

Closes #9086 from chenghao-intel/ceiling.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9808052b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9808052b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9808052b

Branch: refs/heads/master
Commit: 9808052b5adfed7dafd6c1b3971b998e45b2799a
Parents: 4ace4f8
Author: Cheng Hao <ha...@intel.com>
Authored: Wed Oct 14 20:56:08 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Wed Oct 14 20:56:08 2015 -0700

----------------------------------------------------------------------
 .../catalyst/expressions/mathExpressions.scala  | 48 ++++++++++++++++----
 .../org/apache/spark/sql/types/Decimal.scala    | 32 +++++++++++--
 .../catalyst/expressions/LiteralGenerator.scala | 14 +++++-
 .../expressions/MathFunctionsSuite.scala        | 10 ++++
 4 files changed, 91 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9808052b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index a8164e9..28f616f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -55,7 +55,7 @@ abstract class LeafMathExpression(c: Double, name: String)
 abstract class UnaryMathExpression(val f: Double => Double, name: String)
   extends UnaryExpression with Serializable with ImplicitCastInputTypes {
 
-  override def inputTypes: Seq[DataType] = Seq(DoubleType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)
   override def dataType: DataType = DoubleType
   override def nullable: Boolean = true
   override def toString: String = s"$name($child)"
@@ -153,13 +153,28 @@ case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN"
 case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT")
 
 case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") {
-  override def dataType: DataType = LongType
-  protected override def nullSafeEval(input: Any): Any = {
-    f(input.asInstanceOf[Double]).toLong
+  override def dataType: DataType = child.dataType match {
+    case dt @ DecimalType.Fixed(_, 0) => dt
+    case DecimalType.Fixed(precision, scale) =>
+      DecimalType.bounded(precision - scale + 1, 0)
+    case _ => LongType
+  }
+
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(DoubleType, DecimalType))
+
+  protected override def nullSafeEval(input: Any): Any = child.dataType match {
+    case DoubleType => f(input.asInstanceOf[Double]).toLong
+    case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
+    child.dataType match {
+      case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
+      case DecimalType.Fixed(precision, scale) =>
+        defineCodeGen(ctx, ev, c => s"$c.ceil()")
+      case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
+    }
   }
 }
 
@@ -205,13 +220,28 @@ case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP")
 case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1")
 
 case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") {
-  override def dataType: DataType = LongType
-  protected override def nullSafeEval(input: Any): Any = {
-    f(input.asInstanceOf[Double]).toLong
+  override def dataType: DataType = child.dataType match {
+    case dt @ DecimalType.Fixed(_, 0) => dt
+    case DecimalType.Fixed(precision, scale) =>
+      DecimalType.bounded(precision - scale + 1, 0)
+    case _ => LongType
+  }
+
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(DoubleType, DecimalType))
+
+  protected override def nullSafeEval(input: Any): Any = child.dataType match {
+    case DoubleType => f(input.asInstanceOf[Double]).toLong
+    case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
+    child.dataType match {
+      case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
+      case DecimalType.Fixed(precision, scale) =>
+        defineCodeGen(ctx, ev, c => s"$c.floor()")
+      case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
+    }
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9808052b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index c11dab3..c7a1a2e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -107,7 +107,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
    * Set this Decimal to the given BigDecimal value, with a given precision and scale.
    */
   def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
-    this.decimalVal = decimal.setScale(scale, ROUNDING_MODE)
+    this.decimalVal = decimal.setScale(scale, ROUND_HALF_UP)
     require(
       decimalVal.precision <= precision,
       s"Decimal precision ${decimalVal.precision} exceeds max precision $precision")
@@ -198,6 +198,16 @@ final class Decimal extends Ordered[Decimal] with Serializable {
    * @return true if successful, false if overflow would occur
    */
   def changePrecision(precision: Int, scale: Int): Boolean = {
+    changePrecision(precision, scale, ROUND_HALF_UP)
+  }
+
+  /**
+   * Update precision and scale while keeping our value the same, and return true if successful.
+   *
+   * @return true if successful, false if overflow would occur
+   */
+  private[sql] def changePrecision(precision: Int, scale: Int,
+                      roundMode: BigDecimal.RoundingMode.Value): Boolean = {
     // fast path for UnsafeProjection
     if (precision == this.precision && scale == this.scale) {
       return true
@@ -231,7 +241,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
     if (decimalVal.ne(null)) {
       // We get here if either we started with a BigDecimal, or we switched to one because we would
       // have overflowed our Long; in either case we must rescale decimalVal to the new scale.
-      val newVal = decimalVal.setScale(scale, ROUNDING_MODE)
+      val newVal = decimalVal.setScale(scale, roundMode)
       if (newVal.precision > precision) {
         return false
       }
@@ -309,10 +319,26 @@ final class Decimal extends Ordered[Decimal] with Serializable {
   }
 
   def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this
+
+  def floor: Decimal = if (scale == 0) this else {
+    val value = this.clone()
+    value.changePrecision(
+      DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_FLOOR)
+    value
+  }
+
+  def ceil: Decimal = if (scale == 0) this else {
+    val value = this.clone()
+    value.changePrecision(
+      DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_CEILING)
+    value
+  }
 }
 
 object Decimal {
-  private val ROUNDING_MODE = BigDecimal.RoundingMode.HALF_UP
+  val ROUND_HALF_UP = BigDecimal.RoundingMode.HALF_UP
+  val ROUND_CEILING = BigDecimal.RoundingMode.CEILING
+  val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR
 
   /** Maximum number of decimal digits a Long can represent */
   val MAX_LONG_DIGITS = 18

http://git-wip-us.apache.org/repos/asf/spark/blob/9808052b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala
index ee6d251..d9c9141 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala
@@ -78,7 +78,18 @@ object LiteralGenerator {
         Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity)
     } yield Literal.create(f, DoubleType)
 
-  // TODO: decimal type
+  // TODO cache the generated data
+  def decimalLiteralGen(precision: Int, scale: Int): Gen[Literal] = {
+    assert(scale >= 0)
+    assert(precision >= scale)
+    Arbitrary.arbBigInt.arbitrary.map { s =>
+      val a = (s % BigInt(10).pow(precision - scale)).toString()
+      val b = (s % BigInt(10).pow(scale)).abs.toString()
+      Literal.create(
+        Decimal(BigDecimal(s"$a.$b"), precision, scale),
+        DecimalType(precision, scale))
+    }
+  }
 
   lazy val stringLiteralGen: Gen[Literal] =
     for { s <- Arbitrary.arbString.arbitrary } yield Literal.create(s, StringType)
@@ -122,6 +133,7 @@ object LiteralGenerator {
       case StringType => stringLiteralGen
       case BinaryType => binaryLiteralGen
       case CalendarIntervalType => calendarIntervalLiterGen
+      case DecimalType.Fixed(precision, scale) => decimalLiteralGen(precision, scale)
       case dt => throw new IllegalArgumentException(s"not supported type $dt")
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/9808052b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 1b2a916..88ed9fd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -246,11 +246,21 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("ceil") {
     testUnary(Ceil, (d: Double) => math.ceil(d).toLong)
     checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType)
+
+    testUnary(Ceil, (d: Decimal) => d.ceil, (-20 to 20).map(x => Decimal(x * 0.1)))
+    checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3))
+    checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0))
+    checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0))
   }
 
   test("floor") {
     testUnary(Floor, (d: Double) => math.floor(d).toLong)
     checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType)
+
+    testUnary(Floor, (d: Decimal) => d.floor, (-20 to 20).map(x => Decimal(x * 0.1)))
+    checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3))
+    checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0))
+    checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0))
   }
 
   test("factorial") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org