You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/08/31 05:32:48 UTC

[spark] branch master updated: [SPARK-39896][SQL] UnwrapCastInBinaryComparison should work when the literal of In/InSet downcast failed

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6e62b93f3d1 [SPARK-39896][SQL] UnwrapCastInBinaryComparison should work when the literal of In/InSet downcast failed
6e62b93f3d1 is described below

commit 6e62b93f3d1ef7e2d6be0a3bb729ab9b2d55a36d
Author: Fu Chen <cf...@gmail.com>
AuthorDate: Wed Aug 31 13:32:17 2022 +0800

    [SPARK-39896][SQL] UnwrapCastInBinaryComparison should work when the literal of In/InSet downcast failed
    
    ### Why are the changes needed?
    
    This PR fixes the following case:
    
    ```scala
    sql("create table t1(a decimal(3, 0)) using parquet")
    sql("insert into t1 values(100), (10), (1)")
    sql("select * from t1 where a in(100000, 1.00)").show
    ```
    
    ```
    java.lang.RuntimeException: After applying rule org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison in batch Operator Optimization before Inferring Filters, the structural integrity of the plan is broken.
            at org.apache.spark.sql.errors.QueryExecutionErrors$.structuralIntegrityIsBrokenAfterApplyingRuleError(QueryExecutionErrors.scala:1325)
    ```
    
    1. The rule `UnwrapCastInBinaryComparison` rewrites the `In` expression as a disjunction of `EqualTo` comparisons:
    
    ```
    CAST(a as decimal(12,2)) IN (100000.00,1.00)
    
    OR(
       CAST(a as decimal(12,2)) = 100000.00,
       CAST(a as decimal(12,2)) = 1.00
    )
    ```
    
    2. It uses `UnwrapCastInBinaryComparison.unwrapCast()` to optimize each `EqualTo`:
    
    ```
    // Expression1
    CAST(a as decimal(12,2)) = 100000.00 => CAST(a as decimal(12,2)) = 100000.00
    
    // Expression2
    CAST(a as decimal(12,2)) = 1.00      => a = 1
    ```
    
    3. It returns a new `In` expression built from the unwrapped comparisons:
    
    ```
    a IN (100000.00, 1.00)
    ```
    
    Before this PR:
    
    The method `UnwrapCastInBinaryComparison.unwrapCast()` returns the original expression when the literal cannot be downcast to the decimal type of `a` (`Expression1`). It returns an unwrapped expression when the downcast succeeds (`Expression2`). The resulting `In` list therefore mixes values of different data types, which breaks the structural integrity of the plan:
    
    ```
    a IN (100000.00, 1.00)
           |           |
        decimal(12, 2) |
                   decimal(3, 0)
    ```
    
    After this PR:
    
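    This PR drops every literal that cannot be downcast to the type of `fromExp` (out of range, or changed by rounding), since `fromExp` can never equal it. If no literal remains (neither a null literal nor a castable one), the whole expression becomes `falseIfNotNull(fromExp)`, i.e. `isnull(a) AND null`. For the example above the result is:
    ```
    
    a IN (1)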
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, this is a bug fix only.
    
    ### How was this patch tested?
    
    Unit test.
    
    Closes #37439 from cfmcgrady/SPARK-39896.
    
    Authored-by: Fu Chen <cf...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../optimizer/UnwrapCastInBinaryComparison.scala   | 131 ++++++++++-----------
 .../UnwrapCastInBinaryComparisonSuite.scala        |  68 ++++++++---
 2 files changed, 113 insertions(+), 86 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
index 94e27379b74..f4a92760d22 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import scala.collection.immutable.HashSet
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.expressions._
@@ -145,80 +144,28 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
     case in @ In(Cast(fromExp, toType: NumericType, _, _), list @ Seq(firstLit, _*))
       if canImplicitlyCast(fromExp, toType, firstLit.dataType) && in.inSetConvertible =>
 
-      // There are 3 kinds of literals in the list:
-      // 1. null literals
-      // 2. The literals that can cast to fromExp.dataType
-      // 3. The literals that cannot cast to fromExp.dataType
-      // null literals is special as we can cast null literals to any data type.
-      val (nullList, canCastList, cannotCastList) =
-        (ArrayBuffer[Literal](), ArrayBuffer[Literal](), ArrayBuffer[Expression]())
-      list.foreach {
-        case lit @ Literal(null, _) => nullList += lit
-        case lit @ NonNullLiteral(_, _) =>
-          unwrapCast(EqualTo(in.value, lit)) match {
-            case EqualTo(_, unwrapLit: Literal) => canCastList += unwrapLit
-            case e @ And(IsNull(_), Literal(null, BooleanType)) => cannotCastList += e
-            case _ => throw new IllegalStateException("Illegal unwrap cast result found.")
-          }
-        case _ => throw new IllegalStateException("Illegal value found in in.list.")
-      }
-
-      // return original expression when in.list contains only null values.
-      if (canCastList.isEmpty && cannotCastList.isEmpty) {
-        exp
-      } else {
-        // cast null value to fromExp.dataType, to make sure the new return list is in the same data
-        // type.
-        val newList = nullList.map(lit => Cast(lit, fromExp.dataType)) ++ canCastList
-        val unwrapIn = In(fromExp, newList.toSeq)
-        cannotCastList.headOption match {
-          case None => unwrapIn
-          // since `cannotCastList` are all the same,
-          // convert to a single value `And(IsNull(_), Literal(null, BooleanType))`.
-          case Some(falseIfNotNull @ And(IsNull(_), Literal(null, BooleanType)))
-              if cannotCastList.map(_.canonicalized).distinct.length == 1 =>
-            Or(falseIfNotNull, unwrapIn)
-          case _ => exp
-        }
+      val buildIn = {
+        (nullList: ArrayBuffer[Literal], canCastList: ArrayBuffer[Literal]) =>
+          // cast null value to fromExp.dataType, to make sure the new return list is in the same
+          // data type.
+          val newList = nullList.map(lit => Cast(lit, fromExp.dataType)) ++ canCastList
+          In(fromExp, newList.toSeq)
       }
+      simplifyIn(fromExp, toType, list, buildIn).getOrElse(exp)
 
     // The same with `In` expression, the analyzer makes sure that the hset of InSet is already of
     // the same data type, so simply check `fromExp.dataType` can implicitly cast to `toType` and
     // both `fromExp.dataType` and `toType` is numeric type or not.
-    case inSet @ InSet(Cast(fromExp, toType: NumericType, _, _), hset)
+    case InSet(Cast(fromExp, toType: NumericType, _, _), hset)
       if hset.nonEmpty && canImplicitlyCast(fromExp, toType, toType) =>
-
-      // The same with `In`, there are 3 kinds of literals in the hset:
-      // 1. null literals
-      // 2. The literals that can cast to fromExp.dataType
-      // 3. The literals that cannot cast to fromExp.dataType
-      var (nullSet, canCastSet, cannotCastSet) =
-        (HashSet[Any](), HashSet[Any](), HashSet[Expression]())
-      hset.map(value => Literal.create(value, toType))
-        .foreach {
-          case lit @ Literal(null, _) => nullSet += lit.value
-          case lit @ NonNullLiteral(_, _) =>
-            unwrapCast(EqualTo(inSet.child, lit)) match {
-              case EqualTo(_, unwrapLit: Literal) => canCastSet += unwrapLit.value
-              case e @ And(IsNull(_), Literal(null, BooleanType)) => cannotCastSet += e
-              case _ => throw new IllegalStateException("Illegal unwrap cast result found.")
-            }
-          case _ => throw new IllegalStateException("Illegal value found in hset.")
-        }
-
-      if (canCastSet.isEmpty && cannotCastSet.isEmpty) {
-        exp
-      } else {
-        val unwrapInSet = InSet(fromExp, nullSet ++ canCastSet)
-        cannotCastSet.headOption match {
-          case None => unwrapInSet
-          // since `cannotCastList` are all the same,
-          // convert to a single value `And(IsNull(_), Literal(null, BooleanType))`.
-          case Some(falseIfNotNull @ And(IsNull(_), Literal(null, BooleanType)))
-            if cannotCastSet.map(_.canonicalized).size == 1 => Or(falseIfNotNull, unwrapInSet)
-          case _ => exp
-        }
-      }
+      val buildInSet =
+        (nullList: ArrayBuffer[Literal], canCastList: ArrayBuffer[Literal]) =>
+          InSet(fromExp, (nullList ++ canCastList).map(_.value).toSet)
+      simplifyIn(
+        fromExp,
+        toType,
+        hset.map(v => Literal.create(v, toType)).toSeq,
+        buildInSet).getOrElse(exp)
 
     case _ => exp
   }
@@ -346,6 +293,52 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
     }
   }
 
+  private def simplifyIn[IN <: Expression](
+      fromExp: Expression,
+      toType: NumericType,
+      list: Seq[Expression],
+      buildExpr: (ArrayBuffer[Literal], ArrayBuffer[Literal]) => IN): Option[Expression] = {
+
+    // There are 3 kinds of literals in the list:
+    // 1. null literals
+    // 2. The literals that can cast to fromExp.dataType
+    // 3. The literals that cannot cast to fromExp.dataType
+    // Note that:
+    // - null literals are special as we can cast null literals to any data type
+    // - for 3, we have three cases
+    //   1). the literal cannot cast to fromExp.dataType, and there is no min/max for the fromType,
+    //     for instance:
+    //         `cast(input[2, decimal(5,2), true] as decimal(10,4)) = 123456.1234`
+    //   2). the literal value is out of fromType range, for instance:
+    //         `cast(input[0, smallint, true] as bigint) = 2147483647`
+    //   3). the literal value is rounded up/down after casting to `fromType`, for instance:
+    //         `cast(input[1, float, true] as double) = 3.14`
+    //     note that 3.14 will be rounded to 3.14000010... after casting to float
+
+    val (nullList, canCastList) = (ArrayBuffer[Literal](), ArrayBuffer[Literal]())
+    val fromType = fromExp.dataType
+    val ordering = toType.ordering.asInstanceOf[Ordering[Any]]
+
+    list.foreach {
+      case lit @ Literal(null, _) => nullList += lit
+      case NonNullLiteral(value, _) =>
+        val newValue = Cast(Literal(value), fromType, ansiEnabled = false).eval()
+        val valueRoundTrip = Cast(Literal(newValue, fromType), toType).eval()
+        if (newValue != null && ordering.compare(value, valueRoundTrip) == 0) {
+          canCastList += Literal(newValue, fromType)
+        }
+    }
+
+    if (nullList.isEmpty && canCastList.isEmpty) {
+      // only have cannot cast to fromExp.dataType literals
+      Option(falseIfNotNull(fromExp))
+    } else {
+      val unwrapExpr = buildExpr(nullList, canCastList)
+      Option(unwrapExpr)
+    }
+  }
+
+
   /**
    * Check if the input `fromExp` can be safely cast to `toType` without any loss of precision,
    * i.e., the conversion is injective. Note this only handles the case when both sides are of
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
index 75a0565da96..2e3b2708444 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
@@ -248,18 +248,6 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
     val intLit = Literal.create(null, IntegerType)
     val shortLit = Literal.create(null, ShortType)
 
-    def checkInAndInSet(in: In, expected: Expression): Unit = {
-      assertEquivalent(in, expected)
-      val toInSet = (in: In) => InSet(in.value, HashSet() ++ in.list.map(_.eval()))
-      val expectedInSet = expected match {
-        case expectedIn: In =>
-          toInSet(expectedIn)
-        case Or(falseIfNotNull: And, expectedIn: In) =>
-          Or(falseIfNotNull, toInSet(expectedIn))
-      }
-      assertEquivalent(toInSet(in), expectedInSet)
-    }
-
     checkInAndInSet(
       In(Cast(f, LongType), Seq(1.toLong, 2.toLong, 3.toLong)),
       f.in(1.toShort, 2.toShort, 3.toShort))
@@ -267,12 +255,12 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
     // in.list contains the value which out of `fromType` range
     checkInAndInSet(
       In(Cast(f, LongType), Seq(1.toLong, Int.MaxValue.toLong, Long.MaxValue)),
-      Or(falseIfNotNull(f), f.in(1.toShort)))
+      f.in(1.toShort))
 
     // in.list only contains the value which out of `fromType` range
     checkInAndInSet(
       In(Cast(f, LongType), Seq(Int.MaxValue.toLong, Long.MaxValue)),
-      Or(falseIfNotNull(f), f.in()))
+      falseIfNotNull(f))
 
     // in.list is empty
     checkInAndInSet(
@@ -280,17 +268,51 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
 
     // in.list contains null value
     checkInAndInSet(
-      In(Cast(f, IntegerType), Seq(intLit)), In(Cast(f, IntegerType), Seq(intLit)))
+      In(Cast(f, IntegerType), Seq(intLit)), f.in(shortLit))
     checkInAndInSet(
-      In(Cast(f, IntegerType), Seq(intLit, intLit)), In(Cast(f, IntegerType), Seq(intLit, intLit)))
+      In(Cast(f, IntegerType), Seq(intLit, intLit)), f.in(shortLit, shortLit))
     checkInAndInSet(
       In(Cast(f, IntegerType), Seq(intLit, 1)), f.in(shortLit, 1.toShort))
     checkInAndInSet(
       In(Cast(f, LongType), Seq(longLit, 1.toLong, Long.MaxValue)),
-      Or(falseIfNotNull(f), f.in(shortLit, 1.toShort))
+      f.in(shortLit, 1.toShort)
+    )
+    checkInAndInSet(
+      In(Cast(f, LongType), Seq(longLit, Long.MaxValue)),
+      f.in(shortLit)
     )
   }
 
+  test("SPARK-39896: unwrap cast when the literal of In/InSet downcast failed") {
+    val decimalValue = decimal2(123456.1234)
+    val decimalValue2 = decimal2(100.20)
+    checkInAndInSet(
+      In(castDecimal2(f3), Seq(decimalValue, decimalValue2)),
+      f3.in(decimal(decimalValue2)))
+  }
+
+  test("SPARK-39896: unwrap cast when the literal of In/Inset has round up or down") {
+
+    val doubleValue = 1.0
+    val doubleValue1 = 100.6
+    checkInAndInSet(
+      In(castDouble(f), Seq(doubleValue1, doubleValue)),
+      f.in(doubleValue.toShort))
+
+    // Cases for rounding up: 3.14 will be rounded to 3.14000010... after casting to float
+    val doubleValue2 = 3.14
+    checkInAndInSet(
+      In(castDouble(f2), Seq(doubleValue2, doubleValue)),
+      f2.in(doubleValue.toFloat))
+
+    // Another case: 400.5678 is rounded up to 400.57
+    val decimalValue1 = decimal2(400.5678)
+    val decimalValue2 = decimal2(1.0)
+    checkInAndInSet(
+      In(castDecimal2(f3), Seq(decimalValue1, decimalValue2)),
+      f3.in(decimal(decimalValue2)))
+  }
+
   test("SPARK-36130: unwrap In should skip when in.list contains an expression that " +
     "is not literal") {
     val add = Cast(f2, DoubleType) + 1.0d
@@ -375,4 +397,16 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
       })
     }
   }
+
+  private def checkInAndInSet(in: In, expected: Expression): Unit = {
+    assertEquivalent(in, expected)
+    val toInSet = (in: In) => InSet(in.value, HashSet() ++ in.list.map(_.eval()))
+    val expectedInSet = expected match {
+      case expectedIn: In =>
+        toInSet(expectedIn)
+      case falseIfNotNull: And =>
+        falseIfNotNull
+    }
+    assertEquivalent(toInSet(in), expectedInSet)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org