You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/10/29 01:24:25 UTC

spark git commit: [SPARK-19727][SQL][FOLLOWUP] Fix for round function that modifies original column

Repository: spark
Updated Branches:
  refs/heads/master e80da8129 -> 7fdacbc77


[SPARK-19727][SQL][FOLLOWUP] Fix for round function that modifies original column

## What changes were proposed in this pull request?

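This is a follow-up to https://github.com/apache/spark/pull/17075 that fixes the bug in the codegen path: the generated code for `round`/`bround` on decimals called `changePrecision` directly on the input value, which modified the original column. It now calls `toPrecision`, which operates on a copy and returns `null` on overflow.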

## How was this patch tested?

Added a new regression test, "round/bround with table columns", in `MathFunctionsSuite`.

Author: Wenchen Fan <we...@databricks.com>

Closes #19576 from cloud-fan/bug.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7fdacbc7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7fdacbc7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7fdacbc7

Branch: refs/heads/master
Commit: 7fdacbc77bbcf98c2c045a1873e749129769dcc0
Parents: e80da81
Author: Wenchen Fan <we...@databricks.com>
Authored: Sat Oct 28 18:24:18 2017 -0700
Committer: gatorsmile <ga...@gmail.com>
Committed: Sat Oct 28 18:24:18 2017 -0700

----------------------------------------------------------------------
 .../sql/catalyst/CatalystTypeConverters.scala   |  2 +-
 .../spark/sql/catalyst/expressions/Cast.scala   |  3 +-
 .../expressions/decimalExpressions.scala        |  2 +-
 .../catalyst/expressions/mathExpressions.scala  | 10 ++-----
 .../org/apache/spark/sql/types/Decimal.scala    | 31 +++++++++++---------
 .../apache/spark/sql/types/DecimalSuite.scala   |  2 +-
 .../apache/spark/sql/MathFunctionsSuite.scala   | 12 ++++++++
 7 files changed, 36 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7fdacbc7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index d4ebdb1..474ec59 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -310,7 +310,7 @@ object CatalystTypeConverters {
         case d: JavaBigInteger => Decimal(d)
         case d: Decimal => d
       }
-      decimal.toPrecision(dataType.precision, dataType.scale).orNull
+      decimal.toPrecision(dataType.precision, dataType.scale)
     }
     override def toScala(catalystValue: Decimal): JavaBigDecimal = {
       if (catalystValue == null) null

http://git-wip-us.apache.org/repos/asf/spark/blob/7fdacbc7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index d949b8f..bc809f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -387,10 +387,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   /**
    * Create new `Decimal` with precision and scale given in `decimalType` (if any),
    * returning null if it overflows or creating a new `value` and returning it if successful.
-   *
    */
   private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal =
-    value.toPrecision(decimalType.precision, decimalType.scale).orNull
+    value.toPrecision(decimalType.precision, decimalType.scale)
 
 
   private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {

http://git-wip-us.apache.org/repos/asf/spark/blob/7fdacbc7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index c2211ae..752dea2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -85,7 +85,7 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary
   override def nullable: Boolean = true
 
   override def nullSafeEval(input: Any): Any =
-    input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull
+    input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale)
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, eval => {

http://git-wip-us.apache.org/repos/asf/spark/blob/7fdacbc7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 547d5be..d8dc086 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1044,7 +1044,7 @@ abstract class RoundBase(child: Expression, scale: Expression,
     dataType match {
       case DecimalType.Fixed(_, s) =>
         val decimal = input1.asInstanceOf[Decimal]
-        decimal.toPrecision(decimal.precision, s, mode).orNull
+        decimal.toPrecision(decimal.precision, s, mode)
       case ByteType =>
         BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
       case ShortType =>
@@ -1076,12 +1076,8 @@ abstract class RoundBase(child: Expression, scale: Expression,
     val evaluationCode = dataType match {
       case DecimalType.Fixed(_, s) =>
         s"""
-        if (${ce.value}.changePrecision(${ce.value}.precision(), ${s},
-            java.math.BigDecimal.${modeStr})) {
-          ${ev.value} = ${ce.value};
-        } else {
-          ${ev.isNull} = true;
-        }"""
+        ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, Decimal.$modeStr());
+        ${ev.isNull} = ${ev.value} == null;"""
       case ByteType =>
         if (_scale < 0) {
           s"""

http://git-wip-us.apache.org/repos/asf/spark/blob/7fdacbc7/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 1f1fb51..6da4f28 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -234,22 +234,17 @@ final class Decimal extends Ordered[Decimal] with Serializable {
     changePrecision(precision, scale, ROUND_HALF_UP)
   }
 
-  def changePrecision(precision: Int, scale: Int, mode: Int): Boolean = mode match {
-    case java.math.BigDecimal.ROUND_HALF_UP => changePrecision(precision, scale, ROUND_HALF_UP)
-    case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN)
-  }
-
   /**
    * Create new `Decimal` with given precision and scale.
    *
-   * @return `Some(decimal)` if successful or `None` if overflow would occur
+   * @return a non-null `Decimal` value if successful or `null` if overflow would occur.
    */
   private[sql] def toPrecision(
       precision: Int,
       scale: Int,
-      roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = {
+      roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = {
     val copy = clone()
-    if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None
+    if (copy.changePrecision(precision, scale, roundMode)) copy else null
   }
 
   /**
@@ -257,8 +252,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
    *
    * @return true if successful, false if overflow would occur
    */
-  private[sql] def changePrecision(precision: Int, scale: Int,
-                      roundMode: BigDecimal.RoundingMode.Value): Boolean = {
+  private[sql] def changePrecision(
+      precision: Int,
+      scale: Int,
+      roundMode: BigDecimal.RoundingMode.Value): Boolean = {
     // fast path for UnsafeProjection
     if (precision == this.precision && scale == this.scale) {
       return true
@@ -393,14 +390,20 @@ final class Decimal extends Ordered[Decimal] with Serializable {
 
   def floor: Decimal = if (scale == 0) this else {
     val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
-    toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse(
-      throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
+    val res = toPrecision(newPrecision, 0, ROUND_FLOOR)
+    if (res == null) {
+      throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
+    }
+    res
   }
 
   def ceil: Decimal = if (scale == 0) this else {
     val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
-    toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse(
-      throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
+    val res = toPrecision(newPrecision, 0, ROUND_CEILING)
+    if (res == null) {
+      throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
+    }
+    res
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7fdacbc7/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
index 3193d13..10de90c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
@@ -213,7 +213,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
           assert(d.changePrecision(10, 0, mode))
           assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")
 
-          val copy = d.toPrecision(10, 0, mode).orNull
+          val copy = d.toPrecision(10, 0, mode)
           assert(copy !== null)
           assert(d.ne(copy))
           assert(d === copy)

http://git-wip-us.apache.org/repos/asf/spark/blob/7fdacbc7/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
index c2d08a0..5be8c58 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
@@ -258,6 +258,18 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("round/bround with table columns") {
+    withTable("t") {
+      Seq(BigDecimal("5.9")).toDF("i").write.saveAsTable("t")
+      checkAnswer(
+        sql("select i, round(i) from t"),
+        Seq(Row(BigDecimal("5.9"), BigDecimal("6"))))
+      checkAnswer(
+        sql("select i, bround(i) from t"),
+        Seq(Row(BigDecimal("5.9"), BigDecimal("6"))))
+    }
+  }
+
   test("exp") {
     testOneToOneMathFunction(exp, math.exp)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org