You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/08/22 08:35:49 UTC

[spark] branch branch-3.2 updated: [SPARK-40089][SQL] Fix sorting for some Decimal types

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new f14db2cf0dc [SPARK-40089][SQL] Fix sorting for some Decimal types
f14db2cf0dc is described below

commit f14db2cf0dc43ab7e63b873e803d0d66f88e7d85
Author: Robert (Bobby) Evans <bo...@apache.org>
AuthorDate: Mon Aug 22 16:33:37 2022 +0800

    [SPARK-40089][SQL] Fix sorting for some Decimal types
    
    ### What changes were proposed in this pull request?
    This fixes SPARK-40089 (https://issues.apache.org/jira/browse/SPARK-40089), where the sort prefix for some decimal types can overflow and the code assumes that the overflow is always on the negative side, never on the positive side.
    
    ### Why are the changes needed?
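    Previously, whenever the prefix overflowed, Long.MinValue was returned, so a large positive value could sort before smaller values. This adds a check of the value's sign when the overflow happens, so that the proper prefix is returned: Long.MinValue for negative values and Long.MaxValue otherwise.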
    
    ### Does this PR introduce _any_ user-facing change?
    No, unless you consider getting the sort order correct a user-facing change.
    
    ### How was this patch tested?
    I tested manually with the file in the JIRA and I added a small unit test.
    
    Closes #37540 from revans2/fix_dec_sort.
    
    Authored-by: Robert (Bobby) Evans <bo...@apache.org>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 8dfd3dfc115d6e249f00a9a434b866d28e2eae45)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/expressions/SortOrder.scala | 23 +++++++++++++---------
 .../org/apache/spark/sql/execution/SortSuite.scala | 19 ++++++++++++++++++
 2 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 9aef25ce605..403f064616f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -171,7 +171,13 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
       val s = p - (dt.precision - dt.scale)
       (raw) => {
         val value = raw.asInstanceOf[Decimal]
-        if (value.changePrecision(p, s)) value.toUnscaledLong else Long.MinValue
+        if (value.changePrecision(p, s)) {
+          value.toUnscaledLong
+        } else if (value.toBigDecimal.signum < 0) {
+          Long.MinValue
+        } else {
+          Long.MaxValue
+        }
       }
     case dt: DecimalType => (raw) =>
       DoublePrefixComparator.computePrefix(raw.asInstanceOf[Decimal].toDouble)
@@ -204,15 +210,14 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
         s"$DoublePrefixCmp.computePrefix((double)$input)"
       case StringType => s"$StringPrefixCmp.computePrefix($input)"
       case BinaryType => s"$BinaryPrefixCmp.computePrefix($input)"
+      case dt: DecimalType if dt.precision < Decimal.MAX_LONG_DIGITS =>
+        s"$input.toUnscaledLong()"
       case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
-        if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
-          s"$input.toUnscaledLong()"
-        } else {
-          // reduce the scale to fit in a long
-          val p = Decimal.MAX_LONG_DIGITS
-          val s = p - (dt.precision - dt.scale)
-          s"$input.changePrecision($p, $s) ? $input.toUnscaledLong() : ${Long.MinValue}L"
-        }
+        // reduce the scale to fit in a long
+        val p = Decimal.MAX_LONG_DIGITS
+        val s = p - (dt.precision - dt.scale)
+        s"$input.changePrecision($p, $s) ? $input.toUnscaledLong() : " +
+            s"$input.toBigDecimal().signum() < 0 ? ${Long.MinValue}L : ${Long.MaxValue}L"
       case dt: DecimalType =>
         s"$DoublePrefixCmp.computePrefix($input.toDouble())"
       case _ => "0L"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index 6a4f3f62641..6941fe41541 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -110,6 +110,25 @@ class SortSuite extends SparkPlanTest with SharedSparkSession {
       sortAnswers = false)
   }
 
+  test("SPARK-40089: decimal values sort correctly") {
+    val input = Seq(
+      BigDecimal("999999999999999999.50"),
+      BigDecimal("1.11"),
+      BigDecimal("999999999999999999.49")
+    )
+    // The range partitioner does the right thing. If there are too many
+    // shuffle partitions the error might not always show up.
+    withSQLConf("spark.sql.shuffle.partitions" -> "1") {
+      val inputDf = spark.createDataFrame(sparkContext.parallelize(input.map(v => Row(v)), 1),
+        StructType(StructField("a", DecimalType(20, 2)) :: Nil))
+      checkAnswer(
+        inputDf,
+        (child: SparkPlan) => SortExec('a.asc :: Nil, global = true, child = child),
+        input.sorted.map(Row(_)),
+        sortAnswers = false)
+    }
+  }
+
   // Test sorting on different data types
   for (
     dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType);


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org