You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2020/10/06 17:06:29 UTC

[GitHub] [spark] sunchao commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

sunchao commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r500436125



##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
##########
@@ -116,82 +128,118 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
    * optimizes the expression by moving the cast to the literal side. Otherwise if result is not
    * true, this replaces the input binary comparison `exp` with simpler expressions.
    */
-  private def simplifyIntegralComparison(
+  private def simplifyNumericComparison(
       exp: BinaryComparison,
       fromExp: Expression,
-      toType: IntegralType,
+      toType: NumericType,
       value: Any): Expression = {
 
     val fromType = fromExp.dataType
-    val (min, max) = getRange(fromType)
-    val (minInToType, maxInToType) = {
-      (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
-    }
     val ordering = toType.ordering.asInstanceOf[Ordering[Any]]
-    val minCmp = ordering.compare(value, minInToType)
-    val maxCmp = ordering.compare(value, maxInToType)
+    val range = getRange(fromType)
 
-    if (maxCmp > 0) {
-      exp match {
-        case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThan(_, _) | LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        // make sure the expression is evaluated if it is non-deterministic
-        case EqualNullSafe(_, _) if exp.deterministic =>
-          FalseLiteral
-        case _ => exp
+    if (range.isDefined) {
+      val (min, max) = range.get
+      val (minInToType, maxInToType) = {
+        (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
       }
-    } else if (maxCmp == 0) {
-      exp match {
-        case GreaterThan(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        case LessThan(_, _) =>
-          Not(EqualTo(fromExp, Literal(max, fromType)))
-        case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
-          EqualTo(fromExp, Literal(max, fromType))
-        case EqualNullSafe(_, _) =>
-          EqualNullSafe(fromExp, Literal(max, fromType))
-        case _ => exp
+      val minCmp = ordering.compare(value, minInToType)
+      val maxCmp = ordering.compare(value, maxInToType)
+
+      if (maxCmp >= 0 || minCmp <= 0) {
+        return if (maxCmp > 0) {
+          exp match {
+            case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else if (maxCmp == 0) {
+          exp match {
+            case GreaterThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(max, fromType)))
+            case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(max, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(max, fromType))
+            case _ => exp
+          }
+        } else if (minCmp < 0) {
+          exp match {
+            case GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              falseIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else { // minCmp == 0
+          exp match {
+            case LessThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case GreaterThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(min, fromType)))
+            case LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(min, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(min, fromType))
+            case _ => exp
+          }
+        }
       }
-    } else if (minCmp < 0) {
+    }
+
+    // When we reach this point, it means either there is no min/max for the `fromType` (e.g.,
+    // decimal type), or that the literal `value` is within range `(min, max)`. For these, we
+    // optimize by moving the cast to the literal side.
+
+    val newValue = Cast(Literal(value), fromType).eval()
+    if (newValue == null) {
+      // This means the cast failed, for instance, because the value is not representable in the
+      // narrower type. In this case we simply return the original expression.

Review comment:
      Yup, will do — there is also a test case covering this.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org