Posted to commits@spark.apache.org by db...@apache.org on 2019/01/24 00:38:45 UTC

[spark] branch branch-2.3 updated: [SPARK-26706][SQL] Fix `illegalNumericPrecedence` for ByteType

This is an automated email from the ASF dual-hosted git repository.

dbtsai pushed a commit to branch branch-2.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.3 by this push:
     new de3b5c4  [SPARK-26706][SQL] Fix `illegalNumericPrecedence` for ByteType
de3b5c4 is described below

commit de3b5c459869e9ff0979e579828e822e6b01f0e3
Author: Anton Okolnychyi <ao...@apple.com>
AuthorDate: Thu Jan 24 00:12:26 2019 +0000

    [SPARK-26706][SQL] Fix `illegalNumericPrecedence` for ByteType
    
    This PR contains a one-line change to `Cast.illegalNumericPrecedence` (used by `Cast.mayTruncate`) that fixes its logic for bytes.
    
    Right now, `mayTruncate(LongType, ByteType)` returns `false` while `mayTruncate(LongType, ShortType)` returns `true`. The root cause is that `ByteType` is the first element of `TypeCoercion.numericPrecedence`, so its precedence index is 0 and the old `toPrecedence > 0` check always skipped it. Consequently, `spark.range(1, 3).as[Byte]` and `spark.range(1, 3).as[Short]` behave differently.
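
    To make the off-by-one concrete, here is a minimal, self-contained sketch of the check (plain strings stand in for Catalyst types, and the helper names are illustrative, not Spark's actual API):
    ```scala
    // Mirrors TypeCoercion.numericPrecedence: ByteType sits at index 0.
    val numericPrecedence =
      IndexedSeq("ByteType", "ShortType", "IntegerType", "LongType", "FloatType", "DoubleType")

    def illegalNumericPrecedenceOld(from: String, to: String): Boolean = {
      val fromPrecedence = numericPrecedence.indexOf(from)
      val toPrecedence = numericPrecedence.indexOf(to)
      toPrecedence > 0 && fromPrecedence > toPrecedence   // ByteType (index 0) never matches
    }

    def illegalNumericPrecedenceFixed(from: String, to: String): Boolean = {
      val fromPrecedence = numericPrecedence.indexOf(from)
      val toPrecedence = numericPrecedence.indexOf(to)
      toPrecedence >= 0 && fromPrecedence > toPrecedence  // ByteType is now considered
    }

    illegalNumericPrecedenceOld("LongType", "ByteType")   // false -- the bug
    illegalNumericPrecedenceFixed("LongType", "ByteType") // true  -- truncation detected
    ```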
    
    This bug can potentially corrupt someone's data silently, as the example below shows.
    ```scala
    // executes silently even though Long is converted into Byte
    spark.range(Long.MaxValue - 10, Long.MaxValue).as[Byte]
      .map(b => b - 1)
      .show()
    +-----+
    |value|
    +-----+
    |  -12|
    |  -11|
    |  -10|
    |   -9|
    |   -8|
    |   -7|
    |   -6|
    |   -5|
    |   -4|
    |   -3|
    +-----+
    // throws an AnalysisException: Cannot up cast `id` from bigint to smallint as it may truncate
    spark.range(Long.MaxValue - 10, Long.MaxValue).as[Short]
      .map(s => s - 1)
      .show()
    ```
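
    With the fix applied, the `Byte` variant is rejected at analysis time as well; a sketch of the expected behavior, matching the new `DatasetSuite` test below:
    ```scala
    // now throws an AnalysisException: Cannot up cast `id` from bigint to tinyint as it may truncate
    spark.range(Long.MaxValue - 10, Long.MaxValue).as[Byte]
      .map(b => b - 1)
      .show()
    ```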
    
    This PR comes with a set of unit tests.
    
    Closes #23632 from aokolnychyi/cast-fix.
    
    Authored-by: Anton Okolnychyi <ao...@apple.com>
    Signed-off-by: DB Tsai <d_...@apple.com>
---
 .../spark/sql/catalyst/expressions/Cast.scala      |  2 +-
 .../spark/sql/catalyst/expressions/CastSuite.scala | 36 ++++++++++++++++++++++
 .../scala/org/apache/spark/sql/DatasetSuite.scala  |  9 ++++++
 3 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 79b0516..5a156c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -130,7 +130,7 @@ object Cast {
   private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
     val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from)
     val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to)
-    toPrecedence > 0 && fromPrecedence > toPrecedence
+    toPrecedence >= 0 && fromPrecedence > toPrecedence
   }
 
   def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match {
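
For reference, a hedged sketch of how the fixed check surfaces through `Cast.mayTruncate` (assuming a build that includes this patch; the expected values follow the new `CastSuite` assertions below):

```scala
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types._

Cast.mayTruncate(LongType, ByteType)  // true: long -> byte may truncate (was false before the fix)
Cast.mayTruncate(ByteType, LongType)  // false: byte -> long is a safe widening
Cast.mayTruncate(ByteType, ByteType)  // false: an identity cast cannot truncate
```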
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 5b25bdf..777295c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -23,6 +23,7 @@ import java.util.{Calendar, Locale, TimeZone}
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -922,4 +923,39 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType)
     checkEvaluation(ret6, "[1, [1 -> a, 2 -> b, 3 -> c]]")
   }
+
+  test("SPARK-26706: Fix Cast.mayTruncate for bytes") {
+    assert(!Cast.mayTruncate(ByteType, ByteType))
+    assert(!Cast.mayTruncate(DecimalType.ByteDecimal, ByteType))
+    assert(Cast.mayTruncate(ShortType, ByteType))
+    assert(Cast.mayTruncate(IntegerType, ByteType))
+    assert(Cast.mayTruncate(LongType, ByteType))
+    assert(Cast.mayTruncate(FloatType, ByteType))
+    assert(Cast.mayTruncate(DoubleType, ByteType))
+    assert(Cast.mayTruncate(DecimalType.IntDecimal, ByteType))
+  }
+
+  test("canSafeCast and mayTruncate must be consistent for numeric types") {
+    import DataTypeTestUtils._
+
+    def isCastSafe(from: NumericType, to: NumericType): Boolean = (from, to) match {
+      case (_, dt: DecimalType) => dt.isWiderThan(from)
+      case (dt: DecimalType, _) => dt.isTighterThan(to)
+      case _ => numericPrecedence.indexOf(from) <= numericPrecedence.indexOf(to)
+    }
+
+    numericTypes.foreach { from =>
+      val (safeTargetTypes, unsafeTargetTypes) = numericTypes.partition(to => isCastSafe(from, to))
+
+      safeTargetTypes.foreach { to =>
+        assert(Cast.canSafeCast(from, to), s"It should be possible to safely cast $from to $to")
+        assert(!Cast.mayTruncate(from, to), s"No truncation is expected when casting $from to $to")
+      }
+
+      unsafeTargetTypes.foreach { to =>
+        assert(!Cast.canSafeCast(from, to), s"It shouldn't be possible to safely cast $from to $to")
+        assert(Cast.mayTruncate(from, to), s"Truncation is expected when casting $from to $to")
+      }
+    }
+  }
 }
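
The second test above pins down an invariant rather than individual cases: for every pair of numeric types, `Cast.canSafeCast(from, to)` and `Cast.mayTruncate(from, to)` must be exact complements, so any future reordering of `numericPrecedence` that breaks one check will be caught by the other.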
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 522ed8d..ff6777f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1478,6 +1478,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val exceptDF = inputDF.filter(col("a").isin("0") or col("b") > "c")
     checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null)))
   }
+
+  test("SPARK-26706: Fix Cast.mayTruncate for bytes") {
+    val thrownException = intercept[AnalysisException] {
+      spark.range(Long.MaxValue - 10, Long.MaxValue).as[Byte]
+        .map(b => b - 1)
+        .collect()
+    }
+    assert(thrownException.message.contains("Cannot up cast `id` from bigint to tinyint"))
+  }
 }
 
 case class TestDataUnion(x: Int, y: Int, z: Int)

