You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2023/02/03 16:49:11 UTC
[spark] branch branch-3.3 updated: [SPARK-41554] fix changing of Decimal scale when scale decreased by more than 18
This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
new 2d539c5c702 [SPARK-41554] fix changing of Decimal scale when scale decreased by more than 18
2d539c5c702 is described below
commit 2d539c5c7022d44d8a2d53e752287c42c2601444
Author: oleksii.diagiliev <ol...@workday.com>
AuthorDate: Fri Feb 3 10:48:56 2023 -0600
[SPARK-41554] fix changing of Decimal scale when scale decreased by more than 18
This is a backport PR for https://github.com/apache/spark/pull/39099
Closes #39813 from fe2s/branch-3.3-fix-decimal-scaling.
Authored-by: oleksii.diagiliev <ol...@workday.com>
Signed-off-by: Sean Owen <sr...@gmail.com>
---
.../scala/org/apache/spark/sql/types/Decimal.scala | 60 +++++++++++++---------
.../org/apache/spark/sql/types/DecimalSuite.scala | 53 ++++++++++++++++++-
2 files changed, 88 insertions(+), 25 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 7a43d01eb2f..07a2c47cff0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -397,30 +397,42 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (scale < _scale) {
// Easier case: we just need to divide our scale down
val diff = _scale - scale
- val pow10diff = POW_10(diff)
- // % and / always round to 0
- val droppedDigits = longVal % pow10diff
- longVal /= pow10diff
- roundMode match {
- case ROUND_FLOOR =>
- if (droppedDigits < 0) {
- longVal += -1L
- }
- case ROUND_CEILING =>
- if (droppedDigits > 0) {
- longVal += 1L
- }
- case ROUND_HALF_UP =>
- if (math.abs(droppedDigits) * 2 >= pow10diff) {
- longVal += (if (droppedDigits < 0) -1L else 1L)
- }
- case ROUND_HALF_EVEN =>
- val doubled = math.abs(droppedDigits) * 2
- if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) {
- longVal += (if (droppedDigits < 0) -1L else 1L)
- }
- case _ =>
- throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
+ // If diff is greater than max number of digits we store in Long, then
+ // value becomes 0. Otherwise we calculate new value dividing by power of 10.
+ // In both cases we apply rounding after that.
+ if (diff > MAX_LONG_DIGITS) {
+ longVal = roundMode match {
+ case ROUND_FLOOR => if (longVal < 0) -1L else 0L
+ case ROUND_CEILING => if (longVal > 0) 1L else 0L
+ case ROUND_HALF_UP | ROUND_HALF_EVEN => 0L
+ case _ => sys.error(s"Not supported rounding mode: $roundMode")
+ }
+ } else {
+ val pow10diff = POW_10(diff)
+ // % and / always round to 0
+ val droppedDigits = longVal % pow10diff
+ longVal /= pow10diff
+ roundMode match {
+ case ROUND_FLOOR =>
+ if (droppedDigits < 0) {
+ longVal += -1L
+ }
+ case ROUND_CEILING =>
+ if (droppedDigits > 0) {
+ longVal += 1L
+ }
+ case ROUND_HALF_UP =>
+ if (math.abs(droppedDigits) * 2 >= pow10diff) {
+ longVal += (if (droppedDigits < 0) -1L else 1L)
+ }
+ case ROUND_HALF_EVEN =>
+ val doubled = math.abs(droppedDigits) * 2
+ if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) {
+ longVal += (if (droppedDigits < 0) -1L else 1L)
+ }
+ case _ =>
+ throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
+ }
}
} else if (scale > _scale) {
// We might be able to multiply longVal by a power of 10 and not overflow, but if not,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
index 6f70dc51b95..6ccd2b9bd32 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
@@ -27,6 +27,9 @@ import org.apache.spark.sql.types.Decimal._
import org.apache.spark.unsafe.types.UTF8String
class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper {
+
+ val allSupportedRoundModes = Seq(ROUND_HALF_UP, ROUND_HALF_EVEN, ROUND_CEILING, ROUND_FLOOR)
+
/** Check that a Decimal has the given string representation, precision and scale */
private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
assert(d.toString === string)
@@ -222,7 +225,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
}
test("changePrecision/toPrecision on compact decimal should respect rounding mode") {
- Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode =>
+ allSupportedRoundModes.foreach { mode =>
Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n =>
Seq("", "-").foreach { sign =>
val bd = BigDecimal(sign + n)
@@ -315,4 +318,52 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
}
}
}
+
+ // 18 is a max number of digits in Decimal's compact long
+ test("SPARK-41554: decrease/increase scale by 18 and more on compact decimal") {
+ val unscaledNums = Seq(
+ 0L, 1L, 10L, 51L, 123L, 523L,
+ // 18 digits
+ 912345678901234567L,
+ 112345678901234567L,
+ 512345678901234567L
+ )
+ val precision = 38
+ // generate some (from, to) scale pairs, e.g. (38, 18), (-20, -2), etc
+ val scalePairs = for {
+ scale <- Seq(38, 20, 19, 18)
+ delta <- Seq(38, 20, 19, 18)
+ a = scale
+ b = scale - delta
+ } yield {
+ Seq((a, b), (-a, -b), (b, a), (-b, -a))
+ }
+
+ for {
+ unscaled <- unscaledNums
+ mode <- allSupportedRoundModes
+ (scaleFrom, scaleTo) <- scalePairs.flatten
+ sign <- Seq(1L, -1L)
+ } {
+ val unscaledWithSign = unscaled * sign
+ if (scaleFrom < 0 || scaleTo < 0) {
+ withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") {
+ checkScaleChange(unscaledWithSign, scaleFrom, scaleTo, mode)
+ }
+ } else {
+ checkScaleChange(unscaledWithSign, scaleFrom, scaleTo, mode)
+ }
+ }
+
+ def checkScaleChange(unscaled: Long, scaleFrom: Int, scaleTo: Int,
+ roundMode: BigDecimal.RoundingMode.Value): Unit = {
+ val decimal = Decimal(unscaled, precision, scaleFrom)
+ checkCompact(decimal, true)
+ decimal.changePrecision(precision, scaleTo, roundMode)
+ val bd = BigDecimal(unscaled, scaleFrom).setScale(scaleTo, roundMode)
+ assert(decimal.toBigDecimal === bd,
+ s"unscaled: $unscaled, scaleFrom: $scaleFrom, scaleTo: $scaleTo, mode: $roundMode")
+ }
+ }
+
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org