Posted to commits@spark.apache.org by yu...@apache.org on 2023/03/14 00:00:58 UTC

[spark] branch master updated: [SPARK-42597][SQL] Support unwrap date type to timestamp type

This is an automated email from the ASF dual-hosted git repository.

yumwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ab2ae034dc7 [SPARK-42597][SQL] Support unwrap date type to timestamp type
ab2ae034dc7 is described below

commit ab2ae034dc70adc597d3baf0d4b7347daa55caa8
Author: Yuming Wang <yu...@ebay.com>
AuthorDate: Tue Mar 14 08:00:40 2023 +0800

    [SPARK-42597][SQL] Support unwrap date type to timestamp type
    
    ### What changes were proposed in this pull request?
    
    This PR enhances `UnwrapCastInBinaryComparison` to support unwrapping date type to timestamp type.
    
    The rewrites used to unwrap date type to timestamp type are:
    ```
    GreaterThan(Cast(ts, DateType), date) -> GreaterThanOrEqual(ts, Cast(date + 1, TimestampType))
    GreaterThanOrEqual(Cast(ts, DateType), date) -> GreaterThanOrEqual(ts, Cast(date, TimestampType))
    Equality(Cast(ts, DateType), date) -> And(GreaterThanOrEqual(ts, Cast(date, TimestampType)), LessThan(ts, Cast(date + 1, TimestampType)))
    LessThan(Cast(ts, DateType), date) -> LessThan(ts, Cast(date, TimestampType))
    LessThanOrEqual(Cast(ts, DateType), date) -> LessThan(ts, Cast(date + 1, TimestampType))
    ```
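    
    For instance, the first rewrite is sound because `cast(ts as date) > d` holds exactly when `ts` is at or after the start of the day following `d`. A minimal sketch checking this equivalence with `java.time` (the values are illustrative, not from the patch):
    ```scala
    import java.time.{LocalDate, LocalDateTime}

    val d = LocalDate.of(2023, 1, 1)
    val ts = LocalDateTime.of(2023, 1, 1, 23, 59, 59)
    // Left side of the rewrite: cast(ts as date) > d
    val castSide = ts.toLocalDate.isAfter(d)
    // Right side: ts >= cast(d + 1, TimestampType), i.e. midnight of the next day
    val unwrapped = !ts.isBefore(d.plusDays(1).atStartOfDay())
    assert(castSide == unwrapped)
    ```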
    
    ### Why are the changes needed?
    
    Improve query performance.
    
    A common use case: we store cold data in HDFS by partition, store hot data in MySQL, and then union the two for querying. The filter in the MySQL branch cannot be pushed down to the JDBC source, which hurts performance:
    ```sql
    CREATE TABLE t_cold(id bigint, start timestamp, dt date) using parquet PARTITIONED BY (dt);
    CREATE TABLE t_hot(id bigint, start timestamp) using org.apache.spark.sql.jdbc OPTIONS (`url` '***', `dbtable` 'db.t2', `user` 'spark', `password` '***');
    CREATE VIEW all_data AS SELECT * FROM t_cold UNION ALL SELECT *, to_date(start) FROM t_hot;
    SELECT * FROM all_data WHERE dt BETWEEN '2023-02-06' AND '2023-02-07';
    ```
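    
    In the `t_hot` branch, `dt` is `to_date(start)`, i.e. `cast(start AS date)`, so the predicate reaches the JDBC relation as `cast(start AS date) >= DATE '2023-02-06' AND cast(start AS date) <= DATE '2023-02-07'`. With this rule it unwraps to `start >= TIMESTAMP '2023-02-06 00:00:00' AND start < TIMESTAMP '2023-02-08 00:00:00'` (the exact timestamps assume the session time zone), which the JDBC source can push down.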
    
    Before this PR | After this PR
    -- | --
    <img src="https://user-images.githubusercontent.com/5399861/221576723-7fc45356-65db-48e2-8d40-88420c21c9f5.png" width="400" height="730"> | <img src="https://user-images.githubusercontent.com/5399861/221575848-5b975ed0-70ab-4527-acfe-796cc20e169b.png" width="400" height="730">
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit test.
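    
    A quick way to observe the rewrite in `spark-shell` (a sketch; the view name `v` is illustrative and not part of this patch):
    ```scala
    Seq(java.sql.Timestamp.valueOf("2023-01-01 23:59:59")).toDF("ts").createOrReplaceTempView("v")
    // With this rule, the optimized plan should compare ts directly,
    // roughly: ts >= TIMESTAMP '2023-01-02 00:00:00', instead of cast(ts as date) > ...
    spark.sql("SELECT * FROM v WHERE CAST(ts AS DATE) > DATE '2023-01-01'").explain(true)
    ```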
    
    Closes #40190 from wangyum/SPARK-42597.
    
    Authored-by: Yuming Wang <yu...@ebay.com>
    Signed-off-by: Yuming Wang <yu...@ebay.com>
---
 .../optimizer/UnwrapCastInBinaryComparison.scala   | 52 +++++++++++++----
 .../UnwrapCastInBinaryComparisonSuite.scala        | 68 ++++++++++++++++++----
 .../sql/UnwrapCastInComparisonEndToEndSuite.scala  | 24 ++++++++
 3 files changed, 124 insertions(+), 20 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
index f4a92760d22..d95bc694814 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.types._
  *
  * Currently this only handles cases where:
  *   1). `fromType` (of `fromExp`) and `toType` are of numeric types (i.e., short, int, float,
- *     decimal, etc) or boolean type
+ *     decimal, etc), boolean type or datetime type
  *   2). `fromType` can be safely coerced to `toType` without precision loss (e.g., short to int,
  *     int to long, but not long to int, nor int to boolean)
  *
@@ -104,16 +104,15 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
     case l: LogicalPlan =>
       l.transformExpressionsUpWithPruning(
         _.containsAnyPattern(BINARY_COMPARISON, IN, INSET), ruleId) {
-        case e @ (BinaryComparison(_, _) | In(_, _) | InSet(_, _)) => unwrapCast(e)
+        case e @ (BinaryComparison(_, _) | In(_, _) | InSet(_, _)) => unwrapCast(e).getOrElse(e)
       }
   }
 
-  private def unwrapCast(exp: Expression): Expression = exp match {
+  private def unwrapCast(exp: Expression): Option[Expression] = exp match {
     // Not a canonical form. In this case we first canonicalize the expression by swapping the
     // literal and cast side, then process the result and swap the literal and cast again to
     // restore the original order.
-    case BinaryComparison(Literal(_, literalType), Cast(fromExp, toType, _, _))
-        if canImplicitlyCast(fromExp, toType, literalType) =>
+    case BinaryComparison(_: Literal, _: Cast) =>
       def swap(e: Expression): Expression = e match {
         case GreaterThan(left, right) => LessThan(right, left)
         case GreaterThanOrEqual(left, right) => LessThanOrEqual(right, left)
@@ -124,14 +123,19 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
         case _ => e
       }
 
-      swap(unwrapCast(swap(exp)))
+      unwrapCast(swap(exp)).map(swap)
 
     // In case both sides have numeric type, optimize the comparison by removing casts or
     // moving cast to the literal side.
     case be @ BinaryComparison(
       Cast(fromExp, toType: NumericType, _, _), Literal(value, literalType))
         if canImplicitlyCast(fromExp, toType, literalType) =>
-      simplifyNumericComparison(be, fromExp, toType, value)
+      Some(simplifyNumericComparison(be, fromExp, toType, value))
+
+    case be @ BinaryComparison(
+      Cast(fromExp, _, timeZoneId, evalMode), date @ Literal(value, DateType))
+        if AnyTimestampType.acceptsType(fromExp.dataType) && value != null =>
+      Some(unwrapDateToTimestamp(be, fromExp, date, timeZoneId, evalMode))
 
     // As the analyzer makes sure that the list of In is already of the same data type, then the
     // rule can simply check the first literal in `in.list` can implicitly cast to `toType` or not,
@@ -151,7 +155,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
           val newList = nullList.map(lit => Cast(lit, fromExp.dataType)) ++ canCastList
           In(fromExp, newList.toSeq)
       }
-      simplifyIn(fromExp, toType, list, buildIn).getOrElse(exp)
+      simplifyIn(fromExp, toType, list, buildIn)
 
     // The same with `In` expression, the analyzer makes sure that the hset of InSet is already of
     // the same data type, so simply check `fromExp.dataType` can implicitly cast to `toType` and
@@ -165,9 +169,9 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
         fromExp,
         toType,
         hset.map(v => Literal.create(v, toType)).toSeq,
-        buildInSet).getOrElse(exp)
+        buildInSet)
 
-    case _ => exp
+    case _ => None
   }
 
   /**
@@ -293,6 +297,34 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
     }
   }
 
+  /**
+   * Move the cast to the literal side. A date can only be converted to the minimum timestamp
+   * of that day, so some BinaryComparisons need to be adjusted,
+   * e.g. CAST(ts AS date) > DATE '2023-01-01' ===> ts >= TIMESTAMP '2023-01-02 00:00:00'
+   */
+  private def unwrapDateToTimestamp(
+      exp: BinaryComparison,
+      fromExp: Expression,
+      date: Literal,
+      tz: Option[String],
+      evalMode: EvalMode.Value): Expression = {
+    val dateAddOne = DateAdd(date, Literal(1, IntegerType))
+    exp match {
+      case _: GreaterThan =>
+        GreaterThanOrEqual(fromExp, Cast(dateAddOne, fromExp.dataType, tz, evalMode))
+      case _: GreaterThanOrEqual =>
+        GreaterThanOrEqual(fromExp, Cast(date, fromExp.dataType, tz, evalMode))
+      case Equality(_, _) =>
+        And(GreaterThanOrEqual(fromExp, Cast(date, fromExp.dataType, tz, evalMode)),
+          LessThan(fromExp, Cast(dateAddOne, fromExp.dataType, tz, evalMode)))
+      case _: LessThan =>
+        LessThan(fromExp, Cast(date, fromExp.dataType, tz, evalMode))
+      case _: LessThanOrEqual =>
+        LessThan(fromExp, Cast(dateAddOne, fromExp.dataType, tz, evalMode))
+      case _ => exp
+    }
+  }
+
   private def simplifyIn[IN <: Expression](
       fromExp: Expression,
       toType: NumericType,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
index 2e3b2708444..400f2f2c97b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import java.time.{LocalDate, LocalDateTime}
+
 import scala.collection.immutable.HashSet
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -39,11 +41,13 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
   }
 
   val testRelation: LocalRelation = LocalRelation($"a".short, $"b".float,
-    $"c".decimal(5, 2), $"d".boolean)
+    $"c".decimal(5, 2), $"d".boolean, $"e".timestamp, $"f".timestampNTZ)
   val f: BoundReference = $"a".short.canBeNull.at(0)
   val f2: BoundReference = $"b".float.canBeNull.at(1)
   val f3: BoundReference = $"c".decimal(5, 2).canBeNull.at(2)
   val f4: BoundReference = $"d".boolean.canBeNull.at(3)
+  val f5: BoundReference = $"e".timestamp.canBeNull.at(4)
+  val f6: BoundReference = $"f".timestampNTZ.canBeNull.at(5)
 
   test("unwrap casts when literal == max") {
     val v = Short.MaxValue
@@ -368,9 +372,53 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
     assertEquivalent(castInt(f4) < t, trueIfNotNull(f4))
   }
 
+  test("SPARK-42597: Support unwrap date to timestamp type") {
+    val dateLit = Literal.create(LocalDate.of(2023, 1, 1), DateType)
+    val dateAddOne = DateAdd(dateLit, Literal(1))
+    val nullLit = Literal.create(null, DateType)
+
+    assertEquivalent(
+      castDate(f5) > dateLit || castDate(f6) > dateLit,
+      f5 >= castTimestamp(dateAddOne) || f6 >= castTimestampNTZ(dateAddOne))
+    assertEquivalent(
+      castDate(f5) >= dateLit || castDate(f6) >= dateLit,
+      f5 >= castTimestamp(dateLit) || f6 >= castTimestampNTZ(dateLit))
+    assertEquivalent(
+      castDate(f5) < dateLit || castDate(f6) < dateLit,
+      f5 < castTimestamp(dateLit) || f6 < castTimestampNTZ(dateLit))
+    assertEquivalent(
+      castDate(f5) <= dateLit || castDate(f6) <= dateLit,
+      f5 < castTimestamp(dateAddOne) || f6 < castTimestampNTZ(dateAddOne))
+    assertEquivalent(
+      castDate(f5) === dateLit || castDate(f6) === dateLit,
+      (f5 >= castTimestamp(dateLit) && f5 < castTimestamp(dateAddOne)) ||
+        (f6 >= castTimestampNTZ(dateLit) && f6 < castTimestampNTZ(dateAddOne)))
+    assertEquivalent(
+      castDate(f5) <=> dateLit || castDate(f6) === dateLit,
+      (f5 >= castTimestamp(dateLit) && f5 < castTimestamp(dateAddOne)) ||
+        (f6 >= castTimestampNTZ(dateLit) && f6 < castTimestampNTZ(dateAddOne)))
+    assertEquivalent(
+      dateLit < castDate(f5) || dateLit < castDate(f6),
+      castTimestamp(dateAddOne) <= f5 || castTimestampNTZ(dateAddOne) <= f6)
+
+    // Null date literal should be handled by NullPropagation
+    assertEquivalent(castDate(f5) > nullLit || castDate(f6) > nullLit,
+      Literal.create(null, BooleanType) || Literal.create(null, BooleanType))
+  }
+
+  private val ts1 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 99999000)
+  private val ts2 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 999998000)
+  private val ts3 = LocalDateTime.of(9999, 12, 31, 23, 59, 59, 999999999)
+  private val ts4 = LocalDateTime.of(1, 1, 1, 0, 0, 0, 0)
+
   private def castInt(e: Expression): Expression = Cast(e, IntegerType)
   private def castDouble(e: Expression): Expression = Cast(e, DoubleType)
   private def castDecimal2(e: Expression): Expression = Cast(e, DecimalType(10, 4))
+  private def castDate(e: Expression): Expression = Cast(e, DateType)
+  private def castTimestamp(e: Expression): Expression =
+    Cast(e, TimestampType, Some(conf.sessionLocalTimeZone))
+  private def castTimestampNTZ(e: Expression): Expression =
+    Cast(e, TimestampNTZType, Some(conf.sessionLocalTimeZone))
 
   private def decimal(v: Decimal): Decimal = Decimal(v.toJavaBigDecimal, 5, 2)
   private def decimal2(v: BigDecimal): Decimal = Decimal(v, 10, 4)
@@ -383,16 +431,16 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
 
     if (evaluate) {
       Seq(
-        (100.toShort, 3.14.toFloat, decimal2(100), true),
-        (-300.toShort, 3.1415927.toFloat, decimal2(-3000.50), false),
-        (null, Float.NaN, decimal2(12345.6789), null),
-        (null, null, null, null),
-        (Short.MaxValue, Float.PositiveInfinity, decimal2(Short.MaxValue), true),
-        (Short.MinValue, Float.NegativeInfinity, decimal2(Short.MinValue), false),
-        (0.toShort, Float.MaxValue, decimal2(0), null),
-        (0.toShort, Float.MinValue, decimal2(0.01), null)
+        (100.toShort, 3.14.toFloat, decimal2(100), true, ts1, ts1),
+        (-300.toShort, 3.1415927.toFloat, decimal2(-3000.50), false, ts2, ts2),
+        (null, Float.NaN, decimal2(12345.6789), null, null, null),
+        (null, null, null, null, null, null),
+        (Short.MaxValue, Float.PositiveInfinity, decimal2(Short.MaxValue), true, ts3, ts3),
+        (Short.MinValue, Float.NegativeInfinity, decimal2(Short.MinValue), false, ts4, ts4),
+        (0.toShort, Float.MaxValue, decimal2(0), null, null, null),
+        (0.toShort, Float.MinValue, decimal2(0.01), null, null, null)
       ).foreach(v => {
-        val row = create_row(v._1, v._2, v._3, v._4)
+        val row = create_row(v._1, v._2, v._3, v._4, v._5, v._6)
         checkEvaluation(e1, e2.eval(row), row)
       })
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
index 1d7af84ef60..468915aa493 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import java.time.LocalDateTime
+
 import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.Decimal
@@ -240,5 +242,27 @@ class UnwrapCastInComparisonEndToEndSuite extends QueryTest with SharedSparkSess
     }
   }
 
+  test("SPARK-42597: Support unwrap date type to timestamp type") {
+    val ts1 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 99999000)
+    val ts2 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 999998000)
+    val ts3 = LocalDateTime.of(2023, 1, 2, 23, 59, 59, 8000)
+
+    withTable(t) {
+      Seq(ts1, ts2, ts3).toDF("ts").write.saveAsTable(t)
+      val df = spark.table(t)
+
+      checkAnswer(
+        df.where("cast(ts as date) > date'2023-01-01'"), Seq(ts3).map(Row(_)))
+      checkAnswer(
+        df.where("cast(ts as date) >= date'2023-01-01'"), Seq(ts1, ts2, ts3).map(Row(_)))
+      checkAnswer(
+        df.where("cast(ts as date) < date'2023-01-02'"), Seq(ts1, ts2).map(Row(_)))
+      checkAnswer(
+        df.where("cast(ts as date) <= date'2023-01-02'"), Seq(ts1, ts2, ts3).map(Row(_)))
+      checkAnswer(
+        df.where("cast(ts as date) = date'2023-01-01'"), Seq(ts1, ts2).map(Row(_)))
+    }
+  }
+
   private def decimal(v: BigDecimal): Decimal = Decimal(v, 5, 2)
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org