You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/03/08 20:36:19 UTC

spark git commit: [SPARK-19727][SQL] Fix for round function that modifies original column

Repository: spark
Updated Branches:
  refs/heads/master f3387d974 -> e9e2c612d


[SPARK-19727][SQL] Fix for round function that modifies original column

## What changes were proposed in this pull request?

Fix the SQL `round` function, which modified the original column in place when the underlying data frame was created from a local `Seq` of `Product` (for example, case class instances).

    import org.apache.spark.sql.functions._

    case class NumericRow(value: BigDecimal)

    val df = spark.createDataFrame(Seq(NumericRow(BigDecimal("1.23456789"))))

    df.show()
    +--------------------+
    |               value|
    +--------------------+
    |1.234567890000000000|
    +--------------------+

    df.withColumn("value_rounded", round('value)).show()

    // before
    +--------------------+-------------+
    |               value|value_rounded|
    +--------------------+-------------+
    |1.000000000000000000|            1|
    +--------------------+-------------+

    // after
    +--------------------+-------------+
    |               value|value_rounded|
    +--------------------+-------------+
    |1.234567890000000000|            1|
    +--------------------+-------------+

## How was this patch tested?

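New unit test added to the existing suite `org.apache.spark.sql.MathFunctionsSuite`; the existing rounding-mode test in `org.apache.spark.sql.types.DecimalSuite` was also extended to cover the new `Decimal.toPrecision` method.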

Author: Wojtek Szymanski <wk...@gmail.com>

Closes #17075 from wojtek-szymanski/SPARK-19727.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e9e2c612
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e9e2c612
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e9e2c612

Branch: refs/heads/master
Commit: e9e2c612d58a19ddcb4b6abfb7389a4b0f7ef6f8
Parents: f3387d9
Author: Wojtek Szymanski <wk...@gmail.com>
Authored: Wed Mar 8 12:36:16 2017 -0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Wed Mar 8 12:36:16 2017 -0800

----------------------------------------------------------------------
 .../sql/catalyst/CatalystTypeConverters.scala   |  6 +----
 .../spark/sql/catalyst/expressions/Cast.scala   | 13 +++++++--
 .../expressions/decimalExpressions.scala        | 10 ++-----
 .../catalyst/expressions/mathExpressions.scala  |  2 +-
 .../org/apache/spark/sql/types/Decimal.scala    | 28 ++++++++++++++------
 .../apache/spark/sql/types/DecimalSuite.scala   |  8 +++++-
 .../apache/spark/sql/MathFunctionsSuite.scala   | 12 +++++++++
 7 files changed, 54 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e9e2c612/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 5b91615..d4ebdb1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -310,11 +310,7 @@ object CatalystTypeConverters {
         case d: JavaBigInteger => Decimal(d)
         case d: Decimal => d
       }
-      if (decimal.changePrecision(dataType.precision, dataType.scale)) {
-        decimal
-      } else {
-        null
-      }
+      decimal.toPrecision(dataType.precision, dataType.scale).orNull
     }
     override def toScala(catalystValue: Decimal): JavaBigDecimal = {
       if (catalystValue == null) null

http://git-wip-us.apache.org/repos/asf/spark/blob/e9e2c612/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 7c60f7d..1049915 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -352,6 +352,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     if (value.changePrecision(decimalType.precision, decimalType.scale)) value else null
   }
 
+  /**
+   * Create new `Decimal` with precision and scale given in `decimalType` (if any),
+   * returning null if it overflows or creating a new `value` and returning it if successful.
+   *
+   */
+  private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal =
+    value.toPrecision(decimalType.precision, decimalType.scale).orNull
+
+
   private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
     case StringType =>
       buildCast[UTF8String](_, s => try {
@@ -360,14 +369,14 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
         case _: NumberFormatException => null
       })
     case BooleanType =>
-      buildCast[Boolean](_, b => changePrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
+      buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
     case DateType =>
       buildCast[Int](_, d => null) // date can't cast to decimal in Hive
     case TimestampType =>
       // Note that we lose precision here.
       buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target))
     case dt: DecimalType =>
-      b => changePrecision(b.asInstanceOf[Decimal].clone(), target)
+      b => toPrecision(b.asInstanceOf[Decimal], target)
     case t: IntegralType =>
       b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target)
     case x: FractionalType =>

http://git-wip-us.apache.org/repos/asf/spark/blob/e9e2c612/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index fa5dea6..c2211ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -84,14 +84,8 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary
 
   override def nullable: Boolean = true
 
-  override def nullSafeEval(input: Any): Any = {
-    val d = input.asInstanceOf[Decimal].clone()
-    if (d.changePrecision(dataType.precision, dataType.scale)) {
-      d
-    } else {
-      null
-    }
-  }
+  override def nullSafeEval(input: Any): Any =
+    input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, eval => {

http://git-wip-us.apache.org/repos/asf/spark/blob/e9e2c612/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 65273a7..dea5f85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1024,7 +1024,7 @@ abstract class RoundBase(child: Expression, scale: Expression,
     child.dataType match {
       case _: DecimalType =>
         val decimal = input1.asInstanceOf[Decimal]
-        if (decimal.changePrecision(decimal.precision, _scale, mode)) decimal else null
+        decimal.toPrecision(decimal.precision, _scale, mode).orNull
       case ByteType =>
         BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
       case ShortType =>

http://git-wip-us.apache.org/repos/asf/spark/blob/e9e2c612/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 089c84d..e8f6884 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -21,6 +21,7 @@ import java.lang.{Long => JLong}
 import java.math.{BigInteger, MathContext, RoundingMode}
 
 import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.sql.AnalysisException
 
 /**
  * A mutable implementation of BigDecimal that can hold a Long if values are small enough.
@@ -223,6 +224,19 @@ final class Decimal extends Ordered[Decimal] with Serializable {
   }
 
   /**
+   * Create new `Decimal` with given precision and scale.
+   *
+   * @return `Some(decimal)` if successful or `None` if overflow would occur
+   */
+  private[sql] def toPrecision(
+      precision: Int,
+      scale: Int,
+      roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = {
+    val copy = clone()
+    if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None
+  }
+
+  /**
    * Update precision and scale while keeping our value the same, and return true if successful.
    *
    * @return true if successful, false if overflow would occur
@@ -362,17 +376,15 @@ final class Decimal extends Ordered[Decimal] with Serializable {
   def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this
 
   def floor: Decimal = if (scale == 0) this else {
-    val value = this.clone()
-    value.changePrecision(
-      DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_FLOOR)
-    value
+    val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
+    toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse(
+      throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
   }
 
   def ceil: Decimal = if (scale == 0) this else {
-    val value = this.clone()
-    value.changePrecision(
-      DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_CEILING)
-    value
+    val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
+    toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse(
+      throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e9e2c612/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
index 52d0692..714883a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
@@ -193,7 +193,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
     assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue)
   }
 
-  test("changePrecision() on compact decimal should respect rounding mode") {
+  test("changePrecision/toPrecision on compact decimal should respect rounding mode") {
     Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode =>
       Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n =>
         Seq("", "-").foreach { sign =>
@@ -202,6 +202,12 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
           val d = Decimal(unscaled, 8, 1)
           assert(d.changePrecision(10, 0, mode))
           assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")
+
+          val copy = d.toPrecision(10, 0, mode).orNull
+          assert(copy !== null)
+          assert(d.ne(copy))
+          assert(d === copy)
+          assert(copy.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")
         }
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/e9e2c612/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
index 37443d0..328c539 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
@@ -233,6 +233,18 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("round/bround with data frame from a local Seq of Product") {
+    val df = spark.createDataFrame(Seq(Tuple1(BigDecimal("5.9")))).toDF("value")
+    checkAnswer(
+      df.withColumn("value_rounded", round('value)),
+      Seq(Row(BigDecimal("5.9"), BigDecimal("6")))
+    )
+    checkAnswer(
+      df.withColumn("value_brounded", bround('value)),
+      Seq(Row(BigDecimal("5.9"), BigDecimal("6")))
+    )
+  }
+
   test("exp") {
     testOneToOneMathFunction(exp, math.exp)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org