You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2018/12/05 17:13:44 UTC

spark git commit: [SPARK-26233][SQL][BACKPORT-2.2] CheckOverflow when encoding a decimal value

Repository: spark
Updated Branches:
  refs/heads/branch-2.2 9ceee6f18 -> 1c892c00d


[SPARK-26233][SQL][BACKPORT-2.2] CheckOverflow when encoding a decimal value

## What changes were proposed in this pull request?

When we encode a Decimal from an external source, we don't check for overflow. `CheckOverflow` is useful not only to ensure that the value can be represented within the specified range, but also because it changes the underlying data to the correct precision/scale. Since our code generation assumes that a decimal has exactly the same precision and scale as its data type, failing to enforce this can lead to corrupted output/results when there are subsequent transformations.

## How was this patch tested?

Added a unit test in `DatasetSuite`.

Closes #23234 from mgaido91/SPARK-26233_2.2.

Authored-by: Marco Gaido <ma...@gmail.com>
Signed-off-by: Dongjoon Hyun <do...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1c892c00
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1c892c00
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1c892c00

Branch: refs/heads/branch-2.2
Commit: 1c892c00dbcea8d9c9c0674f7daf9c3c19d3761f
Parents: 9ceee6f
Author: Marco Gaido <ma...@gmail.com>
Authored: Wed Dec 5 09:13:36 2018 -0800
Committer: Dongjoon Hyun <do...@apache.org>
Committed: Wed Dec 5 09:13:36 2018 -0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 4 ++--
 .../src/test/scala/org/apache/spark/sql/DatasetSuite.scala  | 9 +++++++++
 2 files changed, 11 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1c892c00/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 0f8282d..368a70a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -106,11 +106,11 @@ object RowEncoder {
         inputObject :: Nil)
 
     case d: DecimalType =>
-      StaticInvoke(
+      CheckOverflow(StaticInvoke(
         Decimal.getClass,
         d,
         "fromDecimal",
-        inputObject :: Nil)
+        inputObject :: Nil), d)
 
     case StringType =>
       StaticInvoke(

http://git-wip-us.apache.org/repos/asf/spark/blob/1c892c00/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 683fe4a..dbaf628 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1238,6 +1238,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       checkDataset(ds, SpecialCharClass("1", "2"))
     }
   }
+
+  test("SPARK-26233: serializer should enforce decimal precision and scale") {
+    val s = StructType(Seq(StructField("a", StringType), StructField("b", DecimalType(38, 8))))
+    val encoder = RowEncoder(s)
+    implicit val uEnc = encoder
+    val df = spark.range(2).map(l => Row(l.toString, BigDecimal.valueOf(l + 0.1111)))
+    checkAnswer(df.groupBy(col("a")).agg(first(col("b"))),
+      Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111))))
+  }
 }
 
 case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org