You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2023/02/03 16:49:11 UTC

[spark] branch branch-3.3 updated: [SPARK-41554] fix changing of Decimal scale when scale decreased by m…

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 2d539c5c702 [SPARK-41554] fix changing of Decimal scale when scale decreased by m…
2d539c5c702 is described below

commit 2d539c5c7022d44d8a2d53e752287c42c2601444
Author: oleksii.diagiliev <ol...@workday.com>
AuthorDate: Fri Feb 3 10:48:56 2023 -0600

    [SPARK-41554] fix changing of Decimal scale when scale decreased by m…
    
    …ore than 18
    
    This is a backport PR for https://github.com/apache/spark/pull/39099
    
    Closes #39813 from fe2s/branch-3.3-fix-decimal-scaling.
    
    Authored-by: oleksii.diagiliev <ol...@workday.com>
    Signed-off-by: Sean Owen <sr...@gmail.com>
---
 .../scala/org/apache/spark/sql/types/Decimal.scala | 60 +++++++++++++---------
 .../org/apache/spark/sql/types/DecimalSuite.scala  | 53 ++++++++++++++++++-
 2 files changed, 88 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 7a43d01eb2f..07a2c47cff0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -397,30 +397,42 @@ final class Decimal extends Ordered[Decimal] with Serializable {
       if (scale < _scale) {
         // Easier case: we just need to divide our scale down
         val diff = _scale - scale
-        val pow10diff = POW_10(diff)
-        // % and / always round to 0
-        val droppedDigits = longVal % pow10diff
-        longVal /= pow10diff
-        roundMode match {
-          case ROUND_FLOOR =>
-            if (droppedDigits < 0) {
-              longVal += -1L
-            }
-          case ROUND_CEILING =>
-            if (droppedDigits > 0) {
-              longVal += 1L
-            }
-          case ROUND_HALF_UP =>
-            if (math.abs(droppedDigits) * 2 >= pow10diff) {
-              longVal += (if (droppedDigits < 0) -1L else 1L)
-            }
-          case ROUND_HALF_EVEN =>
-            val doubled = math.abs(droppedDigits) * 2
-            if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) {
-              longVal += (if (droppedDigits < 0) -1L else 1L)
-            }
-          case _ =>
-            throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
+        // If diff is greater than the max number of digits we store in a Long, every digit is
+        // dropped and the value becomes 0. Otherwise we divide the value by a power of 10.
+        // In both cases we apply rounding after that.
+        if (diff > MAX_LONG_DIGITS) {
+          longVal = roundMode match {
+            case ROUND_FLOOR => if (longVal < 0) -1L else 0L
+            case ROUND_CEILING => if (longVal > 0) 1L else 0L
+            case ROUND_HALF_UP | ROUND_HALF_EVEN => 0L
+            case _ => sys.error(s"Not supported rounding mode: $roundMode")
+          }
+        } else {
+          val pow10diff = POW_10(diff)
+          // % and / always round to 0
+          val droppedDigits = longVal % pow10diff
+          longVal /= pow10diff
+          roundMode match {
+            case ROUND_FLOOR =>
+              if (droppedDigits < 0) {
+                longVal += -1L
+              }
+            case ROUND_CEILING =>
+              if (droppedDigits > 0) {
+                longVal += 1L
+              }
+            case ROUND_HALF_UP =>
+              if (math.abs(droppedDigits) * 2 >= pow10diff) {
+                longVal += (if (droppedDigits < 0) -1L else 1L)
+              }
+            case ROUND_HALF_EVEN =>
+              val doubled = math.abs(droppedDigits) * 2
+              if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) {
+                longVal += (if (droppedDigits < 0) -1L else 1L)
+              }
+            case _ =>
+              throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
+          }
         }
       } else if (scale > _scale) {
         // We might be able to multiply longVal by a power of 10 and not overflow, but if not,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
index 6f70dc51b95..6ccd2b9bd32 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
@@ -27,6 +27,9 @@ import org.apache.spark.sql.types.Decimal._
 import org.apache.spark.unsafe.types.UTF8String
 
 class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper {
+
+  val allSupportedRoundModes = Seq(ROUND_HALF_UP, ROUND_HALF_EVEN, ROUND_CEILING, ROUND_FLOOR)
+
   /** Check that a Decimal has the given string representation, precision and scale */
   private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
     assert(d.toString === string)
@@ -222,7 +225,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
   }
 
   test("changePrecision/toPrecision on compact decimal should respect rounding mode") {
-    Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode =>
+    allSupportedRoundModes.foreach { mode =>
       Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n =>
         Seq("", "-").foreach { sign =>
           val bd = BigDecimal(sign + n)
@@ -315,4 +318,52 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
       }
     }
   }
+
+  // 18 is the max number of digits in Decimal's compact long
+  test("SPARK-41554: decrease/increase scale by 18 and more on compact decimal") {
+    val unscaledNums = Seq(
+      0L, 1L, 10L, 51L, 123L, 523L,
+      // 18 digits
+      912345678901234567L,
+      112345678901234567L,
+      512345678901234567L
+    )
+    val precision = 38
+    // generate some (from, to) scale pairs, e.g. (38, 18), (-20, -2), etc
+    val scalePairs = for {
+      scale <- Seq(38, 20, 19, 18)
+      delta <- Seq(38, 20, 19, 18)
+      a = scale
+      b = scale - delta
+    } yield {
+      Seq((a, b), (-a, -b), (b, a), (-b, -a))
+    }
+
+    for {
+      unscaled <- unscaledNums
+      mode <- allSupportedRoundModes
+      (scaleFrom, scaleTo) <- scalePairs.flatten
+      sign <- Seq(1L, -1L)
+    } {
+      val unscaledWithSign = unscaled * sign
+      if (scaleFrom < 0 || scaleTo < 0) {
+        withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") {
+          checkScaleChange(unscaledWithSign, scaleFrom, scaleTo, mode)
+        }
+      } else {
+        checkScaleChange(unscaledWithSign, scaleFrom, scaleTo, mode)
+      }
+    }
+
+    def checkScaleChange(unscaled: Long, scaleFrom: Int, scaleTo: Int,
+                         roundMode: BigDecimal.RoundingMode.Value): Unit = {
+      val decimal = Decimal(unscaled, precision, scaleFrom)
+      checkCompact(decimal, true)
+      decimal.changePrecision(precision, scaleTo, roundMode)
+      val bd = BigDecimal(unscaled, scaleFrom).setScale(scaleTo, roundMode)
+      assert(decimal.toBigDecimal === bd,
+        s"unscaled: $unscaled, scaleFrom: $scaleFrom, scaleTo: $scaleTo, mode: $roundMode")
+    }
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org