You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ya...@apache.org on 2020/05/24 12:41:31 UTC

[spark] branch branch-3.0 updated: [SPARK-31761][SQL][3.0] cast integer to Long to avoid IntegerOverflow for IntegralDivide operator

This is an automated email from the ASF dual-hosted git repository.

yamamuro pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 72c466e  [SPARK-31761][SQL][3.0] cast integer to Long to avoid IntegerOverflow for IntegralDivide operator
72c466e is described below

commit 72c466e0c37e4cc639040161699b6c0bffde70d5
Author: sandeep katta <sa...@gmail.com>
AuthorDate: Sun May 24 21:39:16 2020 +0900

    [SPARK-31761][SQL][3.0] cast integer to Long to avoid IntegerOverflow for IntegralDivide operator
    
    ### What changes were proposed in this pull request?
    The `IntegralDivide` operator returns a Long, so integer overflow must be handled.
    If the operands are of type Byte, Short or Int, they are now cast to Long before the division.
    
    ### Why are the changes needed?
    Because `IntegralDivide` returns a Long, integral division should never overflow; previously, e.g. `-2147483648 DIV -1` overflowed instead of returning 2147483648.
    
    ### Does this PR introduce _any_ user-facing change?
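    No new APIs. However, integral operands of `div` are now cast to BIGINT, so the auto-generated column name changes (e.g. `(5 div 2)` becomes `(CAST(5 AS BIGINT) div CAST(2 AS BIGINT))`), and results that previously overflowed are now correct.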
    
    ### How was this patch tested?
    Added unit tests and also tested on a local cluster.
    
    After fix
    
    ![image](https://user-images.githubusercontent.com/35216143/82603361-25eccc00-9bd0-11ea-9ca7-001c539e628b.png)
    
    SQL Test
    
    After fix
    ![image](https://user-images.githubusercontent.com/35216143/82637689-f0250300-9c22-11ea-85c3-886ab2c23471.png)
    
    Before fix
    ![image](https://user-images.githubusercontent.com/35216143/82637984-878a5600-9c23-11ea-9e47-5ce2fb923c01.png)
    
    Closes #28628 from sandeep-katta/branch3Backport.
    
    Authored-by: sandeep katta <sa...@gmail.com>
    Signed-off-by: Takeshi Yamamuro <ya...@apache.org>
---
 .../spark/sql/catalyst/analysis/TypeCoercion.scala | 18 ++++++++++++++++
 .../sql/catalyst/expressions/arithmetic.scala      |  2 +-
 .../sql/catalyst/analysis/TypeCoercionSuite.scala  | 24 ++++++++++++++++++++++
 .../expressions/ArithmeticExpressionSuite.scala    |  7 +------
 .../sql-functions/sql-expression-schema.md         |  2 +-
 .../resources/sql-tests/results/operators.sql.out  |  8 ++++----
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala |  8 ++++++++
 7 files changed, 57 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index c6e3f56..a6f8e12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -61,6 +61,7 @@ object TypeCoercion {
       IfCoercion ::
       StackCoercion ::
       Division ::
+      IntegralDivision ::
       ImplicitTypeCasts ::
       DateTimeOperations ::
       WindowFrameCoercion ::
@@ -685,6 +686,23 @@ object TypeCoercion {
   }
 
   /**
+   * The DIV operator always returns a long-type value.
+   * This rule casts byte, short and int inputs to long, to avoid overflow during calculation.
+   */
+  object IntegralDivision extends TypeCoercionRule {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+      case e if !e.childrenResolved => e
+      case d @ IntegralDivide(left, right) =>
+        IntegralDivide(mayCastToLong(left), mayCastToLong(right))
+    }
+
+    private def mayCastToLong(expr: Expression): Expression = expr.dataType match {
+      case _: ByteType | _: ShortType | _: IntegerType => Cast(expr, LongType)
+      case _ => expr
+    }
+  }
+
+  /**
    * Coerces the type of different branches of a CASE WHEN statement to a common type.
    */
   object CaseWhenCoercion extends TypeCoercionRule {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 354845d..7c52183 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -412,7 +412,7 @@ case class IntegralDivide(
     left: Expression,
     right: Expression) extends DivModLike {
 
-  override def inputType: AbstractDataType = TypeCollection(IntegralType, DecimalType)
+  override def inputType: AbstractDataType = TypeCollection(LongType, DecimalType)
 
   override def dataType: DataType = LongType
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index e37555f..1ea1ddb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -1559,6 +1559,30 @@ class TypeCoercionSuite extends AnalysisTest {
           Literal.create(null, DecimalType.SYSTEM_DEFAULT)))
     }
   }
+
+  test("SPARK-31761: byte, short and int should be cast to long for IntegralDivide's datatype") {
+    val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts)
+    // Casts Byte to Long
+    ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toByte, 1.toByte),
+      IntegralDivide(Cast(2.toByte, LongType), Cast(1.toByte, LongType)))
+    // Casts Short to Long
+    ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toShort, 1.toShort),
+      IntegralDivide(Cast(2.toShort, LongType), Cast(1.toShort, LongType)))
+    // Casts Integer to Long
+    ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2, 1),
+      IntegralDivide(Cast(2, LongType), Cast(1, LongType)))
+    // should not be any change for Long data types
+    ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2L, 1L), IntegralDivide(2L, 1L))
+    // one of the operand is byte
+    ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2L, 1.toByte),
+      IntegralDivide(2L, Cast(1.toByte, LongType)))
+    // one of the operand is short
+    ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toShort, 1L),
+      IntegralDivide(Cast(2.toShort, LongType), 1L))
+    // one of the operand is int
+    ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2, 1L),
+      IntegralDivide(Cast(2, LongType), 1L))
+  }
 }
 
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 675f85f..f05598a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -173,13 +173,8 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     }
   }
 
-  test("/ (Divide) for integral type") {
-    checkEvaluation(IntegralDivide(Literal(1.toByte), Literal(2.toByte)), 0L)
-    checkEvaluation(IntegralDivide(Literal(1.toShort), Literal(2.toShort)), 0L)
-    checkEvaluation(IntegralDivide(Literal(1), Literal(2)), 0L)
+  test("/ (Divide) for Long type") {
     checkEvaluation(IntegralDivide(Literal(1.toLong), Literal(2.toLong)), 0L)
-    checkEvaluation(IntegralDivide(positiveShortLit, negativeShortLit), 0L)
-    checkEvaluation(IntegralDivide(positiveIntLit, negativeIntLit), 0L)
     checkEvaluation(IntegralDivide(positiveLongLit, negativeLongLit), 0L)
   }
 
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index c3ae2a7..9e24a54 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -136,7 +136,7 @@
 | org.apache.spark.sql.catalyst.expressions.InputFileBlockLength | input_file_block_length | N/A | N/A |
 | org.apache.spark.sql.catalyst.expressions.InputFileBlockStart | input_file_block_start | N/A | N/A |
 | org.apache.spark.sql.catalyst.expressions.InputFileName | input_file_name | N/A | N/A |
-| org.apache.spark.sql.catalyst.expressions.IntegralDivide | div | SELECT 3 div 2 | struct<(3 div 2):bigint> |
+| org.apache.spark.sql.catalyst.expressions.IntegralDivide | div | SELECT 3 div 2 | struct<(CAST(3 AS BIGINT) div CAST(2 AS BIGINT)):bigint> |
 | org.apache.spark.sql.catalyst.expressions.IsNaN | isnan | SELECT isnan(cast('NaN' as double)) | struct<isnan(CAST(NaN AS DOUBLE)):boolean> |
 | org.apache.spark.sql.catalyst.expressions.IsNotNull | isnotnull | SELECT isnotnull(1) | struct<(1 IS NOT NULL):boolean> |
 | org.apache.spark.sql.catalyst.expressions.IsNull | isnull | SELECT isnull(1) | struct<(1 IS NULL):boolean> |
diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
index a94a123..9accc57 100644
--- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
@@ -157,7 +157,7 @@ NULL
 -- !query
 select 5 div 2
 -- !query schema
-struct<(5 div 2):bigint>
+struct<(CAST(5 AS BIGINT) div CAST(2 AS BIGINT)):bigint>
 -- !query output
 2
 
@@ -165,7 +165,7 @@ struct<(5 div 2):bigint>
 -- !query
 select 5 div 0
 -- !query schema
-struct<(5 div 0):bigint>
+struct<(CAST(5 AS BIGINT) div CAST(0 AS BIGINT)):bigint>
 -- !query output
 NULL
 
@@ -173,7 +173,7 @@ NULL
 -- !query
 select 5 div null
 -- !query schema
-struct<(5 div CAST(NULL AS INT)):bigint>
+struct<(CAST(5 AS BIGINT) div CAST(NULL AS BIGINT)):bigint>
 -- !query output
 NULL
 
@@ -181,7 +181,7 @@ NULL
 -- !query
 select null div 5
 -- !query schema
-struct<(CAST(NULL AS INT) div 5):bigint>
+struct<(CAST(NULL AS BIGINT) div CAST(5 AS BIGINT)):bigint>
 -- !query output
 NULL
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index d336f52..a23e583 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -3441,6 +3441,14 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       assert(SQLConf.get.getConf(SQLConf.CODEGEN_FALLBACK) === true)
     }
   }
+
+  test("SPARK-31761: test byte, short, integer overflow for (Divide) integral type") {
+    checkAnswer(sql("Select -2147483648 DIV -1"), Seq(Row(Integer.MIN_VALUE.toLong * -1)))
+    checkAnswer(sql("select CAST(-128 as Byte) DIV CAST (-1 as Byte)"),
+      Seq(Row(Byte.MinValue.toLong * -1)))
+    checkAnswer(sql("select CAST(-32768 as short) DIV CAST (-1 as short)"),
+      Seq(Row(Short.MinValue.toLong * -1)))
+  }
 }
 
 case class Foo(bar: Option[String])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org