You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2016/04/18 19:45:03 UTC

spark git commit: [SPARK-14614] [SQL] Add `bround` function

Repository: spark
Updated Branches:
  refs/heads/master d6fb485de -> 432d1399c


[SPARK-14614] [SQL] Add `bround` function

## What changes were proposed in this pull request?

This PR aims to add the `bround` function (a.k.a. banker's rounding) by extending the current `round` implementation. [Hive supports `bround` since 1.3.0.](https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF)

**Hive (1.3 ~ 2.0)**
```
hive> select round(2.5), bround(2.5);
OK
3.0	2.0
```

**After this PR**
```scala
scala> sql("select round(2.5), bround(2.5)").head
res0: org.apache.spark.sql.Row = [3,2]
```

## How was this patch tested?

Pass the Jenkins tests (with extended tests).

Author: Dongjoon Hyun <do...@apache.org>

Closes #12376 from dongjoon-hyun/SPARK-14614.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/432d1399
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/432d1399
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/432d1399

Branch: refs/heads/master
Commit: 432d1399cb6985893932088875b2f3be981c0b5f
Parents: d6fb485
Author: Dongjoon Hyun <do...@apache.org>
Authored: Mon Apr 18 10:44:51 2016 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Mon Apr 18 10:44:51 2016 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../catalyst/expressions/mathExpressions.scala  | 71 +++++++++++++-------
 .../org/apache/spark/sql/types/Decimal.scala    |  6 ++
 .../analysis/ExpressionTypeCheckingSuite.scala  | 10 ++-
 .../expressions/MathFunctionsSuite.scala        | 23 ++++++-
 .../scala/org/apache/spark/sql/functions.scala  | 17 +++++
 .../apache/spark/sql/MathExpressionsSuite.scala | 12 +++-
 7 files changed, 113 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/432d1399/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 028463e..ed19191 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -179,6 +179,7 @@ object FunctionRegistry {
     expression[Atan]("atan"),
     expression[Atan2]("atan2"),
     expression[Bin]("bin"),
+    expression[BRound]("bround"),
     expression[Cbrt]("cbrt"),
     expression[Ceil]("ceil"),
     expression[Ceil]("ceiling"),

http://git-wip-us.apache.org/repos/asf/spark/blob/432d1399/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index c8a28e8..9e19028 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -779,7 +779,6 @@ case class Logarithm(left: Expression, right: Expression)
 /**
  * Round the `child`'s result to `scale` decimal place when `scale` >= 0
  * or round at integral part when `scale` < 0.
- * For example, round(31.415, 2) = 31.42 and round(31.415, -1) = 30.
  *
  * Child of IntegralType would round to itself when `scale` >= 0.
  * Child of FractionalType whose value is NaN or Infinite would always round to itself.
@@ -789,16 +788,12 @@ case class Logarithm(left: Expression, right: Expression)
  *
  * @param child expr to be round, all [[NumericType]] is allowed as Input
  * @param scale new scale to be round to, this should be a constant int at runtime
+ * @param mode rounding mode (e.g. HALF_UP, HALF_EVEN)
+ * @param modeStr rounding mode string name (e.g. "ROUND_HALF_UP", "ROUND_HALF_EVEN")
  */
-@ExpressionDescription(
-  usage = "_FUNC_(x, d) - Round x to d decimal places.",
-  extended = "> SELECT _FUNC_(12.3456, 1);\n 12.3")
-case class Round(child: Expression, scale: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
-
-  import BigDecimal.RoundingMode.HALF_UP
-
-  def this(child: Expression) = this(child, Literal(0))
+abstract class RoundBase(child: Expression, scale: Expression,
+    mode: BigDecimal.RoundingMode.Value, modeStr: String)
+  extends BinaryExpression with Serializable with ImplicitCastInputTypes {
 
   override def left: Expression = child
   override def right: Expression = scale
@@ -853,28 +848,28 @@ case class Round(child: Expression, scale: Expression)
     child.dataType match {
       case _: DecimalType =>
         val decimal = input1.asInstanceOf[Decimal]
-        if (decimal.changePrecision(decimal.precision, _scale)) decimal else null
+        if (decimal.changePrecision(decimal.precision, _scale, mode)) decimal else null
       case ByteType =>
-        BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte
+        BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
       case ShortType =>
-        BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_UP).toShort
+        BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, mode).toShort
       case IntegerType =>
-        BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt
+        BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, mode).toInt
       case LongType =>
-        BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong
+        BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, mode).toLong
       case FloatType =>
         val f = input1.asInstanceOf[Float]
         if (f.isNaN || f.isInfinite) {
           f
         } else {
-          BigDecimal(f.toDouble).setScale(_scale, HALF_UP).toFloat
+          BigDecimal(f.toDouble).setScale(_scale, mode).toFloat
         }
       case DoubleType =>
         val d = input1.asInstanceOf[Double]
         if (d.isNaN || d.isInfinite) {
           d
         } else {
-          BigDecimal(d).setScale(_scale, HALF_UP).toDouble
+          BigDecimal(d).setScale(_scale, mode).toDouble
         }
     }
   }
@@ -885,7 +880,8 @@ case class Round(child: Expression, scale: Expression)
     val evaluationCode = child.dataType match {
       case _: DecimalType =>
         s"""
-        if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale})) {
+        if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale},
+            java.math.BigDecimal.${modeStr})) {
           ${ev.value} = ${ce.value};
         } else {
           ${ev.isNull} = true;
@@ -894,7 +890,7 @@ case class Round(child: Expression, scale: Expression)
         if (_scale < 0) {
           s"""
           ${ev.value} = new java.math.BigDecimal(${ce.value}).
-            setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();"""
+            setScale(${_scale}, java.math.BigDecimal.${modeStr}).byteValue();"""
         } else {
           s"${ev.value} = ${ce.value};"
         }
@@ -902,7 +898,7 @@ case class Round(child: Expression, scale: Expression)
         if (_scale < 0) {
           s"""
           ${ev.value} = new java.math.BigDecimal(${ce.value}).
-            setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();"""
+            setScale(${_scale}, java.math.BigDecimal.${modeStr}).shortValue();"""
         } else {
           s"${ev.value} = ${ce.value};"
         }
@@ -910,7 +906,7 @@ case class Round(child: Expression, scale: Expression)
         if (_scale < 0) {
           s"""
           ${ev.value} = new java.math.BigDecimal(${ce.value}).
-            setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();"""
+            setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValue();"""
         } else {
           s"${ev.value} = ${ce.value};"
         }
@@ -918,7 +914,7 @@ case class Round(child: Expression, scale: Expression)
         if (_scale < 0) {
           s"""
           ${ev.value} = new java.math.BigDecimal(${ce.value}).
-            setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();"""
+            setScale(${_scale}, java.math.BigDecimal.${modeStr}).longValue();"""
         } else {
           s"${ev.value} = ${ce.value};"
         }
@@ -928,7 +924,7 @@ case class Round(child: Expression, scale: Expression)
             ${ev.value} = ${ce.value};
           } else {
             ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}).
-              setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue();
+              setScale(${_scale}, java.math.BigDecimal.${modeStr}).floatValue();
           }"""
       case DoubleType => // if child eval to NaN or Infinity, just return it.
         s"""
@@ -936,7 +932,7 @@ case class Round(child: Expression, scale: Expression)
             ${ev.value} = ${ce.value};
           } else {
             ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}).
-              setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
+              setScale(${_scale}, java.math.BigDecimal.${modeStr}).doubleValue();
           }"""
     }
 
@@ -957,3 +953,30 @@ case class Round(child: Expression, scale: Expression)
     }
   }
 }
+
+/**
+ * Round an expression to d decimal places using HALF_UP rounding mode.
+ * round(2.5) == 3.0, round(3.5) == 4.0.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(x, d) - Round x to d decimal places using HALF_UP rounding mode.",
+  extended = "> SELECT _FUNC_(2.5, 0);\n 3.0")
+case class Round(child: Expression, scale: Expression)
+  extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_UP, "ROUND_HALF_UP")
+    with Serializable with ImplicitCastInputTypes {
+  def this(child: Expression) = this(child, Literal(0))
+}
+
+/**
+ * Round an expression to d decimal places using HALF_EVEN rounding mode,
+ * also known as Gaussian rounding or bankers' rounding.
+ * bround(2.5) == 2.0, bround(3.5) == 4.0.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(x, d) - Round x to d decimal places using HALF_EVEN rounding mode.",
+  extended = "> SELECT _FUNC_(2.5, 0);\n 2.0")
+case class BRound(child: Expression, scale: Expression)
+  extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_EVEN, "ROUND_HALF_EVEN")
+    with Serializable with ImplicitCastInputTypes {
+  def this(child: Expression) = this(child, Literal(0))
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/432d1399/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index a30a392..6f4ec6b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -201,6 +201,11 @@ final class Decimal extends Ordered[Decimal] with Serializable {
     changePrecision(precision, scale, ROUND_HALF_UP)
   }
 
+  def changePrecision(precision: Int, scale: Int, mode: Int): Boolean = mode match {
+    case java.math.BigDecimal.ROUND_HALF_UP => changePrecision(precision, scale, ROUND_HALF_UP)
+    case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN)
+  }
+
   /**
    * Update precision and scale while keeping our value the same, and return true if successful.
    *
@@ -337,6 +342,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
 
 object Decimal {
   val ROUND_HALF_UP = BigDecimal.RoundingMode.HALF_UP
+  val ROUND_HALF_EVEN = BigDecimal.RoundingMode.HALF_EVEN
   val ROUND_CEILING = BigDecimal.RoundingMode.CEILING
   val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR
 

http://git-wip-us.apache.org/repos/asf/spark/blob/432d1399/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index ace6e10..660dc86 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -192,7 +192,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
       "values of function map should all be the same type")
   }
 
-  test("check types for ROUND") {
+  test("check types for ROUND/BROUND") {
     assertSuccess(Round(Literal(null), Literal(null)))
     assertSuccess(Round('intField, Literal(1)))
 
@@ -200,6 +200,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertError(Round('intField, 'booleanField), "requires int type")
     assertError(Round('intField, 'mapField), "requires int type")
     assertError(Round('booleanField, 'intField), "requires numeric type")
+
+    assertSuccess(BRound(Literal(null), Literal(null)))
+    assertSuccess(BRound('intField, Literal(1)))
+
+    assertError(BRound('intField, 'intField), "Only foldable Expression is allowed")
+    assertError(BRound('intField, 'booleanField), "requires int type")
+    assertError(BRound('intField, 'mapField), "requires int type")
+    assertError(BRound('booleanField, 'intField), "requires numeric type")
   }
 
   test("check types for Greatest/Least") {

http://git-wip-us.apache.org/repos/asf/spark/blob/432d1399/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 452792d..1e5b657 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -508,7 +508,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkConsistencyBetweenInterpretedAndCodegen(Logarithm, DoubleType, DoubleType)
   }
 
-  test("round") {
+  test("round/bround") {
     val scales = -6 to 6
     val doublePi: Double = math.Pi
     val shortPi: Short = 31415
@@ -529,11 +529,18 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++
       Seq.fill(7)(31415926535897932L)
 
+    val intResultsB: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300,
+      314159260) ++ Seq.fill(7)(314159265)
+
     scales.zipWithIndex.foreach { case (scale, i) =>
       checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow)
       checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow)
       checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow)
       checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow)
+      checkEvaluation(BRound(doublePi, scale), doubleResults(i), EmptyRow)
+      checkEvaluation(BRound(shortPi, scale), shortResults(i), EmptyRow)
+      checkEvaluation(BRound(intPi, scale), intResultsB(i), EmptyRow)
+      checkEvaluation(BRound(longPi, scale), longResults(i), EmptyRow)
     }
 
     val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
@@ -543,19 +550,33 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
     (0 to 7).foreach { i =>
       checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow)
+      checkEvaluation(BRound(bdPi, i), bdResults(i), EmptyRow)
     }
     (8 to 10).foreach { scale =>
       checkEvaluation(Round(bdPi, scale), null, EmptyRow)
+      checkEvaluation(BRound(bdPi, scale), null, EmptyRow)
     }
 
     DataTypeTestUtils.numericTypes.foreach { dataType =>
       checkEvaluation(Round(Literal.create(null, dataType), Literal(2)), null)
       checkEvaluation(Round(Literal.create(null, dataType),
         Literal.create(null, IntegerType)), null)
+      checkEvaluation(BRound(Literal.create(null, dataType), Literal(2)), null)
+      checkEvaluation(BRound(Literal.create(null, dataType),
+        Literal.create(null, IntegerType)), null)
     }
 
+    checkEvaluation(Round(2.5, 0), 3.0)
+    checkEvaluation(Round(3.5, 0), 4.0)
+    checkEvaluation(Round(-2.5, 0), -3.0)
     checkEvaluation(Round(-3.5, 0), -4.0)
     checkEvaluation(Round(-0.35, 1), -0.4)
     checkEvaluation(Round(-35, -1), -40)
+    checkEvaluation(BRound(2.5, 0), 2.0)
+    checkEvaluation(BRound(3.5, 0), 4.0)
+    checkEvaluation(BRound(-2.5, 0), -2.0)
+    checkEvaluation(BRound(-3.5, 0), -4.0)
+    checkEvaluation(BRound(-0.35, 1), -0.4)
+    checkEvaluation(BRound(-35, -1), -40)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/432d1399/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 2231223..8e2e946 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1777,6 +1777,23 @@ object functions {
   def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) }
 
   /**
+   * Returns the value of the column `e` rounded to 0 decimal places with HALF_EVEN round mode.
+   *
+   * @group math_funcs
+   * @since 2.0.0
+   */
+  def bround(e: Column): Column = bround(e, 0)
+
+  /**
+   * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode
+   * if `scale` >= 0 or at integral part when `scale` < 0.
+   *
+   * @group math_funcs
+   * @since 2.0.0
+   */
+  def bround(e: Column, scale: Int): Column = withExpr { BRound(e.expr, Literal(scale)) }
+
+  /**
    * Shift the given value numBits left. If the given value is a long value, this function
    * will return a long value else it will return an integer value.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/432d1399/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index f5a67fd..0de7f23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -207,12 +207,16 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
     testOneToOneMathFunction(rint, math.rint)
   }
 
-  test("round") {
+  test("round/bround") {
     val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a")
     checkAnswer(
       df.select(round('a), round('a, -1), round('a, -2)),
       Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600))
     )
+    checkAnswer(
+      df.select(bround('a), bround('a, -1), bround('a, -2)),
+      Seq(Row(5, 0, 0), Row(55, 60, 100), Row(555, 560, 600))
+    )
 
     val pi = "3.1415"
     checkAnswer(
@@ -221,6 +225,12 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
       Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
         BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
     )
+    checkAnswer(
+      sql(s"SELECT bround($pi, -3), bround($pi, -2), bround($pi, -1), " +
+        s"bround($pi, 0), bround($pi, 1), bround($pi, 2), bround($pi, 3)"),
+      Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
+        BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
+    )
   }
 
   test("exp") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org