You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/08/31 05:32:48 UTC
[spark] branch master updated: [SPARK-39896][SQL] UnwrapCastInBinaryComparison should work when the literal of In/InSet downcast failed
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6e62b93f3d1 [SPARK-39896][SQL] UnwrapCastInBinaryComparison should work when the literal of In/InSet downcast failed
6e62b93f3d1 is described below
commit 6e62b93f3d1ef7e2d6be0a3bb729ab9b2d55a36d
Author: Fu Chen <cf...@gmail.com>
AuthorDate: Wed Aug 31 13:32:17 2022 +0800
[SPARK-39896][SQL] UnwrapCastInBinaryComparison should work when the literal of In/InSet downcast failed
### Why are the changes needed?
This PR fixes the following case, which currently fails during optimization:
```scala
sql("create table t1(a decimal(3, 0)) using parquet")
sql("insert into t1 values(100), (10), (1)")
sql("select * from t1 where a in(100000, 1.00)").show
```
```
java.lang.RuntimeException: After applying rule org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison in batch Operator Optimization before Inferring Filters, the structural integrity of the plan is broken.
at org.apache.spark.sql.errors.QueryExecutionErrors$.structuralIntegrityIsBrokenAfterApplyingRuleError(QueryExecutionErrors.scala:1325)
```
1. The rule `UnwrapCastInBinaryComparison` conceptually transforms the `In` expression into a disjunction of `EqualTo` comparisons, one per literal in the list:
```
CAST(a as decimal(12,2)) IN (100000.00,1.00)
OR(
CAST(a as decimal(12,2)) = 100000.00,
CAST(a as decimal(12,2)) = 1.00
)
```
2. It then uses `UnwrapCastInBinaryComparison.unwrapCast()` to optimize each `EqualTo` independently:
```
// Expression1
CAST(a as decimal(12,2)) = 100000.00 => CAST(a as decimal(12,2)) = 100000.00
// Expression2
CAST(a as decimal(12,2)) = 1.00 => a = 1
```
3. Finally, it collects the unwrapped literals and returns a new `In` expression without the cast:
```
a IN (100000.00, 1.00)
```
Before this PR:
the method `UnwrapCastInBinaryComparison.unwrapCast()` returns the original expression when downcasting the literal to the decimal type fails (`Expression1`), and returns the unwrapped expression when the downcast succeeds (`Expression2`). The literals of the two results have different data types, so the rebuilt `In` list mixes data types, which breaks the structural integrity of the plan:
```
a IN (100000.00, 1.00)
          |       |
   decimal(12, 2) |
            decimal(3, 0)
```
After this PR:
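literals that fail to downcast to the data type of `fromExp` are dropped from the list; if no null literal and no castable literal remains, the whole predicate is rewritten to `falseIfNotNull(fromExp)`, i.e. `isnull(a) AND null`. For the example above, the result is:
```
a IN (1)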
```
### Does this PR introduce _any_ user-facing change?
No, only bug fix.
### How was this patch tested?
Unit test.
Closes #37439 from cfmcgrady/SPARK-39896.
Authored-by: Fu Chen <cf...@gmail.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../optimizer/UnwrapCastInBinaryComparison.scala | 131 ++++++++++-----------
.../UnwrapCastInBinaryComparisonSuite.scala | 68 ++++++++---
2 files changed, 113 insertions(+), 86 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
index 94e27379b74..f4a92760d22 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.optimizer
-import scala.collection.immutable.HashSet
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions._
@@ -145,80 +144,28 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
case in @ In(Cast(fromExp, toType: NumericType, _, _), list @ Seq(firstLit, _*))
if canImplicitlyCast(fromExp, toType, firstLit.dataType) && in.inSetConvertible =>
- // There are 3 kinds of literals in the list:
- // 1. null literals
- // 2. The literals that can cast to fromExp.dataType
- // 3. The literals that cannot cast to fromExp.dataType
- // null literals is special as we can cast null literals to any data type.
- val (nullList, canCastList, cannotCastList) =
- (ArrayBuffer[Literal](), ArrayBuffer[Literal](), ArrayBuffer[Expression]())
- list.foreach {
- case lit @ Literal(null, _) => nullList += lit
- case lit @ NonNullLiteral(_, _) =>
- unwrapCast(EqualTo(in.value, lit)) match {
- case EqualTo(_, unwrapLit: Literal) => canCastList += unwrapLit
- case e @ And(IsNull(_), Literal(null, BooleanType)) => cannotCastList += e
- case _ => throw new IllegalStateException("Illegal unwrap cast result found.")
- }
- case _ => throw new IllegalStateException("Illegal value found in in.list.")
- }
-
- // return original expression when in.list contains only null values.
- if (canCastList.isEmpty && cannotCastList.isEmpty) {
- exp
- } else {
- // cast null value to fromExp.dataType, to make sure the new return list is in the same data
- // type.
- val newList = nullList.map(lit => Cast(lit, fromExp.dataType)) ++ canCastList
- val unwrapIn = In(fromExp, newList.toSeq)
- cannotCastList.headOption match {
- case None => unwrapIn
- // since `cannotCastList` are all the same,
- // convert to a single value `And(IsNull(_), Literal(null, BooleanType))`.
- case Some(falseIfNotNull @ And(IsNull(_), Literal(null, BooleanType)))
- if cannotCastList.map(_.canonicalized).distinct.length == 1 =>
- Or(falseIfNotNull, unwrapIn)
- case _ => exp
- }
+ val buildIn = {
+ (nullList: ArrayBuffer[Literal], canCastList: ArrayBuffer[Literal]) =>
+ // cast null value to fromExp.dataType, to make sure the new return list is in the same
+ // data type.
+ val newList = nullList.map(lit => Cast(lit, fromExp.dataType)) ++ canCastList
+ In(fromExp, newList.toSeq)
}
+ simplifyIn(fromExp, toType, list, buildIn).getOrElse(exp)
// The same with `In` expression, the analyzer makes sure that the hset of InSet is already of
// the same data type, so simply check `fromExp.dataType` can implicitly cast to `toType` and
// both `fromExp.dataType` and `toType` is numeric type or not.
- case inSet @ InSet(Cast(fromExp, toType: NumericType, _, _), hset)
+ case InSet(Cast(fromExp, toType: NumericType, _, _), hset)
if hset.nonEmpty && canImplicitlyCast(fromExp, toType, toType) =>
-
- // The same with `In`, there are 3 kinds of literals in the hset:
- // 1. null literals
- // 2. The literals that can cast to fromExp.dataType
- // 3. The literals that cannot cast to fromExp.dataType
- var (nullSet, canCastSet, cannotCastSet) =
- (HashSet[Any](), HashSet[Any](), HashSet[Expression]())
- hset.map(value => Literal.create(value, toType))
- .foreach {
- case lit @ Literal(null, _) => nullSet += lit.value
- case lit @ NonNullLiteral(_, _) =>
- unwrapCast(EqualTo(inSet.child, lit)) match {
- case EqualTo(_, unwrapLit: Literal) => canCastSet += unwrapLit.value
- case e @ And(IsNull(_), Literal(null, BooleanType)) => cannotCastSet += e
- case _ => throw new IllegalStateException("Illegal unwrap cast result found.")
- }
- case _ => throw new IllegalStateException("Illegal value found in hset.")
- }
-
- if (canCastSet.isEmpty && cannotCastSet.isEmpty) {
- exp
- } else {
- val unwrapInSet = InSet(fromExp, nullSet ++ canCastSet)
- cannotCastSet.headOption match {
- case None => unwrapInSet
- // since `cannotCastList` are all the same,
- // convert to a single value `And(IsNull(_), Literal(null, BooleanType))`.
- case Some(falseIfNotNull @ And(IsNull(_), Literal(null, BooleanType)))
- if cannotCastSet.map(_.canonicalized).size == 1 => Or(falseIfNotNull, unwrapInSet)
- case _ => exp
- }
- }
+ val buildInSet =
+ (nullList: ArrayBuffer[Literal], canCastList: ArrayBuffer[Literal]) =>
+ InSet(fromExp, (nullList ++ canCastList).map(_.value).toSet)
+ simplifyIn(
+ fromExp,
+ toType,
+ hset.map(v => Literal.create(v, toType)).toSeq,
+ buildInSet).getOrElse(exp)
case _ => exp
}
@@ -346,6 +293,52 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
}
}
+ private def simplifyIn[IN <: Expression](
+ fromExp: Expression,
+ toType: NumericType,
+ list: Seq[Expression],
+ buildExpr: (ArrayBuffer[Literal], ArrayBuffer[Literal]) => IN): Option[Expression] = {
+
+ // There are 3 kinds of literals in the list:
+ // 1. null literals
+ // 2. The literals that can cast to fromExp.dataType
+ // 3. The literals that cannot cast to fromExp.dataType
+ // Note that:
+ // - null literals are special as we can cast null literals to any data type
+ // - for 3, we have three cases
+ // 1). the literal cannot cast to fromExp.dataType, and there is no min/max for the fromType,
+ // for instance:
+ // `cast(input[2, decimal(5,2), true] as decimal(10,4)) = 123456.1234`
+ // 2). the literal value is out of fromType range, for instance:
+ // `cast(input[0, smallint, true] as bigint) = 2147483647`
+ // 3). the literal value is rounded up/down after casting to `fromType`, for instance:
+ // `cast(input[1, float, true] as double) = 3.14`
+ // note that 3.14 will be rounded to 3.14000010... after casting to float
+
+ val (nullList, canCastList) = (ArrayBuffer[Literal](), ArrayBuffer[Literal]())
+ val fromType = fromExp.dataType
+ val ordering = toType.ordering.asInstanceOf[Ordering[Any]]
+
+ list.foreach {
+ case lit @ Literal(null, _) => nullList += lit
+ case NonNullLiteral(value, _) =>
+ val newValue = Cast(Literal(value), fromType, ansiEnabled = false).eval()
+ val valueRoundTrip = Cast(Literal(newValue, fromType), toType).eval()
+ if (newValue != null && ordering.compare(value, valueRoundTrip) == 0) {
+ canCastList += Literal(newValue, fromType)
+ }
+ }
+
+ if (nullList.isEmpty && canCastList.isEmpty) {
+ // only have cannot cast to fromExp.dataType literals
+ Option(falseIfNotNull(fromExp))
+ } else {
+ val unwrapExpr = buildExpr(nullList, canCastList)
+ Option(unwrapExpr)
+ }
+ }
+
+
/**
* Check if the input `fromExp` can be safely cast to `toType` without any loss of precision,
* i.e., the conversion is injective. Note this only handles the case when both sides are of
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
index 75a0565da96..2e3b2708444 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
@@ -248,18 +248,6 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
val intLit = Literal.create(null, IntegerType)
val shortLit = Literal.create(null, ShortType)
- def checkInAndInSet(in: In, expected: Expression): Unit = {
- assertEquivalent(in, expected)
- val toInSet = (in: In) => InSet(in.value, HashSet() ++ in.list.map(_.eval()))
- val expectedInSet = expected match {
- case expectedIn: In =>
- toInSet(expectedIn)
- case Or(falseIfNotNull: And, expectedIn: In) =>
- Or(falseIfNotNull, toInSet(expectedIn))
- }
- assertEquivalent(toInSet(in), expectedInSet)
- }
-
checkInAndInSet(
In(Cast(f, LongType), Seq(1.toLong, 2.toLong, 3.toLong)),
f.in(1.toShort, 2.toShort, 3.toShort))
@@ -267,12 +255,12 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
// in.list contains the value which out of `fromType` range
checkInAndInSet(
In(Cast(f, LongType), Seq(1.toLong, Int.MaxValue.toLong, Long.MaxValue)),
- Or(falseIfNotNull(f), f.in(1.toShort)))
+ f.in(1.toShort))
// in.list only contains the value which out of `fromType` range
checkInAndInSet(
In(Cast(f, LongType), Seq(Int.MaxValue.toLong, Long.MaxValue)),
- Or(falseIfNotNull(f), f.in()))
+ falseIfNotNull(f))
// in.list is empty
checkInAndInSet(
@@ -280,17 +268,51 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
// in.list contains null value
checkInAndInSet(
- In(Cast(f, IntegerType), Seq(intLit)), In(Cast(f, IntegerType), Seq(intLit)))
+ In(Cast(f, IntegerType), Seq(intLit)), f.in(shortLit))
checkInAndInSet(
- In(Cast(f, IntegerType), Seq(intLit, intLit)), In(Cast(f, IntegerType), Seq(intLit, intLit)))
+ In(Cast(f, IntegerType), Seq(intLit, intLit)), f.in(shortLit, shortLit))
checkInAndInSet(
In(Cast(f, IntegerType), Seq(intLit, 1)), f.in(shortLit, 1.toShort))
checkInAndInSet(
In(Cast(f, LongType), Seq(longLit, 1.toLong, Long.MaxValue)),
- Or(falseIfNotNull(f), f.in(shortLit, 1.toShort))
+ f.in(shortLit, 1.toShort)
+ )
+ checkInAndInSet(
+ In(Cast(f, LongType), Seq(longLit, Long.MaxValue)),
+ f.in(shortLit)
)
}
+ test("SPARK-39896: unwrap cast when the literal of In/InSet downcast failed") {
+ val decimalValue = decimal2(123456.1234)
+ val decimalValue2 = decimal2(100.20)
+ checkInAndInSet(
+ In(castDecimal2(f3), Seq(decimalValue, decimalValue2)),
+ f3.in(decimal(decimalValue2)))
+ }
+
+ test("SPARK-39896: unwrap cast when the literal of In/Inset has round up or down") {
+
+ val doubleValue = 1.0
+ val doubleValue1 = 100.6
+ checkInAndInSet(
+ In(castDouble(f), Seq(doubleValue1, doubleValue)),
+ f.in(doubleValue.toShort))
+
+ // Cases for rounding up: 3.14 will be rounded to 3.14000010... after casting to float
+ val doubleValue2 = 3.14
+ checkInAndInSet(
+ In(castDouble(f2), Seq(doubleValue2, doubleValue)),
+ f2.in(doubleValue.toFloat))
+
+ // Another case: 400.5678 is rounded up to 400.57
+ val decimalValue1 = decimal2(400.5678)
+ val decimalValue2 = decimal2(1.0)
+ checkInAndInSet(
+ In(castDecimal2(f3), Seq(decimalValue1, decimalValue2)),
+ f3.in(decimal(decimalValue2)))
+ }
+
test("SPARK-36130: unwrap In should skip when in.list contains an expression that " +
"is not literal") {
val add = Cast(f2, DoubleType) + 1.0d
@@ -375,4 +397,16 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
})
}
}
+
+ private def checkInAndInSet(in: In, expected: Expression): Unit = {
+ assertEquivalent(in, expected)
+ val toInSet = (in: In) => InSet(in.value, HashSet() ++ in.list.map(_.eval()))
+ val expectedInSet = expected match {
+ case expectedIn: In =>
+ toInSet(expectedIn)
+ case falseIfNotNull: And =>
+ falseIfNotNull
+ }
+ assertEquivalent(toInSet(in), expectedInSet)
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org