You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ya...@apache.org on 2020/05/24 12:41:31 UTC
[spark] branch branch-3.0 updated: [SPARK-31761][SQL][3.0] cast
integer to Long to avoid IntegerOverflow for IntegralDivide operator
This is an automated email from the ASF dual-hosted git repository.
yamamuro pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new 72c466e [SPARK-31761][SQL][3.0] cast integer to Long to avoid IntegerOverflow for IntegralDivide operator
72c466e is described below
commit 72c466e0c37e4cc639040161699b6c0bffde70d5
Author: sandeep katta <sa...@gmail.com>
AuthorDate: Sun May 24 21:39:16 2020 +0900
[SPARK-31761][SQL][3.0] cast integer to Long to avoid IntegerOverflow for IntegralDivide operator
### What changes were proposed in this pull request?
The `IntegralDivide` operator returns the Long DataType, so the integer-overflow case should be handled.
If the operands are of type Int, they will be cast to Long.
### Why are the changes needed?
As `IntegralDivide` returns the Long datatype, integer overflow should not happen
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added unit tests and also tested on a local cluster
After fix
![image](https://user-images.githubusercontent.com/35216143/82603361-25eccc00-9bd0-11ea-9ca7-001c539e628b.png)
SQL Test
After fix
![image](https://user-images.githubusercontent.com/35216143/82637689-f0250300-9c22-11ea-85c3-886ab2c23471.png)
Before Fix
![image](https://user-images.githubusercontent.com/35216143/82637984-878a5600-9c23-11ea-9e47-5ce2fb923c01.png)
Closes #28628 from sandeep-katta/branch3Backport.
Authored-by: sandeep katta <sa...@gmail.com>
Signed-off-by: Takeshi Yamamuro <ya...@apache.org>
---
.../spark/sql/catalyst/analysis/TypeCoercion.scala | 18 ++++++++++++++++
.../sql/catalyst/expressions/arithmetic.scala | 2 +-
.../sql/catalyst/analysis/TypeCoercionSuite.scala | 24 ++++++++++++++++++++++
.../expressions/ArithmeticExpressionSuite.scala | 7 +------
.../sql-functions/sql-expression-schema.md | 2 +-
.../resources/sql-tests/results/operators.sql.out | 8 ++++----
.../scala/org/apache/spark/sql/SQLQuerySuite.scala | 8 ++++++++
7 files changed, 57 insertions(+), 12 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index c6e3f56..a6f8e12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -61,6 +61,7 @@ object TypeCoercion {
IfCoercion ::
StackCoercion ::
Division ::
+ IntegralDivision ::
ImplicitTypeCasts ::
DateTimeOperations ::
WindowFrameCoercion ::
@@ -685,6 +686,23 @@ object TypeCoercion {
}
/**
+ * The DIV operator always returns long-type value.
+ * This rule cast the integral inputs to long type, to avoid overflow during calculation.
+ */
+ object IntegralDivision extends TypeCoercionRule {
+ override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+ case e if !e.childrenResolved => e
+ case d @ IntegralDivide(left, right) =>
+ IntegralDivide(mayCastToLong(left), mayCastToLong(right))
+ }
+
+ private def mayCastToLong(expr: Expression): Expression = expr.dataType match {
+ case _: ByteType | _: ShortType | _: IntegerType => Cast(expr, LongType)
+ case _ => expr
+ }
+ }
+
+ /**
* Coerces the type of different branches of a CASE WHEN statement to a common type.
*/
object CaseWhenCoercion extends TypeCoercionRule {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 354845d..7c52183 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -412,7 +412,7 @@ case class IntegralDivide(
left: Expression,
right: Expression) extends DivModLike {
- override def inputType: AbstractDataType = TypeCollection(IntegralType, DecimalType)
+ override def inputType: AbstractDataType = TypeCollection(LongType, DecimalType)
override def dataType: DataType = LongType
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index e37555f..1ea1ddb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -1559,6 +1559,30 @@ class TypeCoercionSuite extends AnalysisTest {
Literal.create(null, DecimalType.SYSTEM_DEFAULT)))
}
}
+
+ test("SPARK-31761: byte, short and int should be cast to long for IntegralDivide's datatype") {
+ val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts)
+ // Casts Byte to Long
+ ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toByte, 1.toByte),
+ IntegralDivide(Cast(2.toByte, LongType), Cast(1.toByte, LongType)))
+ // Casts Short to Long
+ ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toShort, 1.toShort),
+ IntegralDivide(Cast(2.toShort, LongType), Cast(1.toShort, LongType)))
+ // Casts Integer to Long
+ ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2, 1),
+ IntegralDivide(Cast(2, LongType), Cast(1, LongType)))
+ // should not be any change for Long data types
+ ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2L, 1L), IntegralDivide(2L, 1L))
+ // one of the operand is byte
+ ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2L, 1.toByte),
+ IntegralDivide(2L, Cast(1.toByte, LongType)))
+ // one of the operand is short
+ ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toShort, 1L),
+ IntegralDivide(Cast(2.toShort, LongType), 1L))
+ // one of the operand is int
+ ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2, 1L),
+ IntegralDivide(Cast(2, LongType), 1L))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 675f85f..f05598a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -173,13 +173,8 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
}
}
- test("/ (Divide) for integral type") {
- checkEvaluation(IntegralDivide(Literal(1.toByte), Literal(2.toByte)), 0L)
- checkEvaluation(IntegralDivide(Literal(1.toShort), Literal(2.toShort)), 0L)
- checkEvaluation(IntegralDivide(Literal(1), Literal(2)), 0L)
+ test("/ (Divide) for Long type") {
checkEvaluation(IntegralDivide(Literal(1.toLong), Literal(2.toLong)), 0L)
- checkEvaluation(IntegralDivide(positiveShortLit, negativeShortLit), 0L)
- checkEvaluation(IntegralDivide(positiveIntLit, negativeIntLit), 0L)
checkEvaluation(IntegralDivide(positiveLongLit, negativeLongLit), 0L)
}
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index c3ae2a7..9e24a54 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -136,7 +136,7 @@
| org.apache.spark.sql.catalyst.expressions.InputFileBlockLength | input_file_block_length | N/A | N/A |
| org.apache.spark.sql.catalyst.expressions.InputFileBlockStart | input_file_block_start | N/A | N/A |
| org.apache.spark.sql.catalyst.expressions.InputFileName | input_file_name | N/A | N/A |
-| org.apache.spark.sql.catalyst.expressions.IntegralDivide | div | SELECT 3 div 2 | struct<(3 div 2):bigint> |
+| org.apache.spark.sql.catalyst.expressions.IntegralDivide | div | SELECT 3 div 2 | struct<(CAST(3 AS BIGINT) div CAST(2 AS BIGINT)):bigint> |
| org.apache.spark.sql.catalyst.expressions.IsNaN | isnan | SELECT isnan(cast('NaN' as double)) | struct<isnan(CAST(NaN AS DOUBLE)):boolean> |
| org.apache.spark.sql.catalyst.expressions.IsNotNull | isnotnull | SELECT isnotnull(1) | struct<(1 IS NOT NULL):boolean> |
| org.apache.spark.sql.catalyst.expressions.IsNull | isnull | SELECT isnull(1) | struct<(1 IS NULL):boolean> |
diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
index a94a123..9accc57 100644
--- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
@@ -157,7 +157,7 @@ NULL
-- !query
select 5 div 2
-- !query schema
-struct<(5 div 2):bigint>
+struct<(CAST(5 AS BIGINT) div CAST(2 AS BIGINT)):bigint>
-- !query output
2
@@ -165,7 +165,7 @@ struct<(5 div 2):bigint>
-- !query
select 5 div 0
-- !query schema
-struct<(5 div 0):bigint>
+struct<(CAST(5 AS BIGINT) div CAST(0 AS BIGINT)):bigint>
-- !query output
NULL
@@ -173,7 +173,7 @@ NULL
-- !query
select 5 div null
-- !query schema
-struct<(5 div CAST(NULL AS INT)):bigint>
+struct<(CAST(5 AS BIGINT) div CAST(NULL AS BIGINT)):bigint>
-- !query output
NULL
@@ -181,7 +181,7 @@ NULL
-- !query
select null div 5
-- !query schema
-struct<(CAST(NULL AS INT) div 5):bigint>
+struct<(CAST(NULL AS BIGINT) div CAST(5 AS BIGINT)):bigint>
-- !query output
NULL
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index d336f52..a23e583 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -3441,6 +3441,14 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
assert(SQLConf.get.getConf(SQLConf.CODEGEN_FALLBACK) === true)
}
}
+
+ test("SPARK-31761: test byte, short, integer overflow for (Divide) integral type") {
+ checkAnswer(sql("Select -2147483648 DIV -1"), Seq(Row(Integer.MIN_VALUE.toLong * -1)))
+ checkAnswer(sql("select CAST(-128 as Byte) DIV CAST (-1 as Byte)"),
+ Seq(Row(Byte.MinValue.toLong * -1)))
+ checkAnswer(sql("select CAST(-32768 as short) DIV CAST (-1 as short)"),
+ Seq(Row(Short.MinValue.toLong * -1)))
+ }
}
case class Foo(bar: Option[String])
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org