Posted to commits@spark.apache.org by ma...@apache.org on 2022/06/03 06:12:45 UTC

[spark] branch branch-3.3 updated: [SPARK-39259][SQL][3.3] Evaluate timestamps consistently in subqueries

This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 4a0f0ff6c22 [SPARK-39259][SQL][3.3] Evaluate timestamps consistently in subqueries
4a0f0ff6c22 is described below

commit 4a0f0ff6c22b85cb0fc1eef842da8dbe4c90543a
Author: Ole Sasse <ol...@databricks.com>
AuthorDate: Fri Jun 3 09:12:26 2022 +0300

    [SPARK-39259][SQL][3.3] Evaluate timestamps consistently in subqueries
    
    ### What changes were proposed in this pull request?
    
    Apply the optimizer rule ComputeCurrentTime consistently across the
    main query and all of its subqueries, so that every current-time
    expression in a plan is replaced with a literal derived from a single
    instant.
    
    This is a backport of https://github.com/apache/spark/pull/36654.
    
    ### Why are the changes needed?
    
    At the moment, timestamp functions like now() can return different
    values within a single query if subqueries are involved.
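
    For illustration, a minimal sketch of the kind of query that could
    observe two different values before this fix (assuming a spark-shell
    session named `spark`; the query is illustrative, not taken from a
    real test):

    ```scala
    // Before this change the rule did not descend into subquery plans, so
    // the two now() calls below could be folded to different literals.
    spark.sql(
      """SELECT now() AS outer_ts,
        |       (SELECT now() FROM range(1)) AS subquery_ts
        |""".stripMargin).show(false)
    ```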
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    A new unit test was added.
    
    Closes #36752 from olaky/SPARK-39259-spark_3_3.
    
    Authored-by: Ole Sasse <ol...@databricks.com>
    Signed-off-by: Max Gekk <ma...@gmail.com>
---
 .../sql/catalyst/optimizer/finishAnalysis.scala    | 41 +++++-----
 .../spark/sql/catalyst/plans/QueryPlan.scala       | 11 ++-
 .../optimizer/ComputeCurrentTimeSuite.scala        | 89 ++++++++++++++++------
 3 files changed, 95 insertions(+), 46 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
index ef9c4b9af40..242c799dd22 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -17,14 +17,16 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import scala.collection.mutable
+import java.time.{Instant, LocalDateTime}
 
 import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreePattern._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ}
+import org.apache.spark.sql.catalyst.trees.TreePatternBits
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, instantToMicros, localDateTimeToMicros}
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -73,29 +75,30 @@ object RewriteNonCorrelatedExists extends Rule[LogicalPlan] {
  */
 object ComputeCurrentTime extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = {
-    val currentDates = mutable.Map.empty[String, Literal]
-    val timeExpr = CurrentTimestamp()
-    val timestamp = timeExpr.eval(EmptyRow).asInstanceOf[Long]
-    val currentTime = Literal.create(timestamp, timeExpr.dataType)
+    val instant = Instant.now()
+    val currentTimestampMicros = instantToMicros(instant)
+    val currentTime = Literal.create(currentTimestampMicros, TimestampType)
     val timezone = Literal.create(conf.sessionLocalTimeZone, StringType)
-    val localTimestamps = mutable.Map.empty[String, Literal]
 
-    plan.transformAllExpressionsWithPruning(_.containsPattern(CURRENT_LIKE)) {
-      case currentDate @ CurrentDate(Some(timeZoneId)) =>
-        currentDates.getOrElseUpdate(timeZoneId, {
-          Literal.create(currentDate.eval().asInstanceOf[Int], DateType)
-        })
-      case CurrentTimestamp() | Now() => currentTime
-      case CurrentTimeZone() => timezone
-      case localTimestamp @ LocalTimestamp(Some(timeZoneId)) =>
-        localTimestamps.getOrElseUpdate(timeZoneId, {
-          Literal.create(localTimestamp.eval().asInstanceOf[Long], TimestampNTZType)
-        })
+    def transformCondition(treePatternBits: TreePatternBits): Boolean = {
+      treePatternBits.containsPattern(CURRENT_LIKE)
+    }
+
+    plan.transformDownWithSubqueries(transformCondition) {
+      case subQuery =>
+        subQuery.transformAllExpressionsWithPruning(transformCondition) {
+          case cd: CurrentDate =>
+            Literal.create(DateTimeUtils.microsToDays(currentTimestampMicros, cd.zoneId), DateType)
+          case CurrentTimestamp() | Now() => currentTime
+          case CurrentTimeZone() => timezone
+          case localTimestamp: LocalTimestamp =>
+            val asDateTime = LocalDateTime.ofInstant(instant, localTimestamp.zoneId)
+            Literal.create(localDateTimeToMicros(asDateTime), TimestampNTZType)
+        }
     }
   }
 }
 
-
 /**
  * Replaces the expression of CurrentDatabase with the current database name.
  * Replaces the expression of CurrentCatalog with the current catalog name.
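
As a standalone sketch of the rule's core idea (plain Scala, not Spark code): capture a single java.time.Instant once per optimizer run and derive every current-time value from it, so that all occurrences of a given function agree no matter where in the plan, or in which subquery, they appear.

```scala
import java.time.{Instant, LocalDateTime, ZoneId}

object ConsistentNowSketch extends App {
  // Evaluated exactly once, mirroring the single Instant.now() in the rule.
  val instant = Instant.now()
  val currentTimestampMicros =
    instant.getEpochSecond * 1000000L + instant.getNano / 1000L

  // Zone-dependent values are derived from the same shared instant, so two
  // occurrences of localtimestamp in one plan can never disagree.
  def localTimestamp(zone: String): LocalDateTime =
    LocalDateTime.ofInstant(instant, ZoneId.of(zone))

  println(currentTimestampMicros)
  println(localTimestamp("UTC"))
  println(localTimestamp("Asia/Kolkata")) // a 30-minute-offset zone
}
```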
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 0f8df5df376..d0283f4d367 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -454,7 +454,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
    * to rewrite the whole plan, include its subqueries, in one go.
    */
   def transformWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType =
-    transformDownWithSubqueries(f)
+    transformDownWithSubqueries(AlwaysProcess.fn, UnknownRuleId)(f)
 
   /**
    * Returns a copy of this node where the given partial function has been recursively applied
@@ -479,7 +479,10 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
    * first to this node, then this node's subqueries and finally this node's children.
    * When the partial function does not apply to a given node, it is left unchanged.
    */
-  def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = {
+  def transformDownWithSubqueries(
+      cond: TreePatternBits => Boolean = AlwaysProcess.fn,
+      ruleId: RuleId = UnknownRuleId)
+      (f: PartialFunction[PlanType, PlanType]): PlanType = {
     val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] {
       override def isDefinedAt(x: PlanType): Boolean = true
 
@@ -487,13 +490,13 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
         val transformed = f.applyOrElse[PlanType, PlanType](plan, identity)
         transformed transformExpressionsDown {
           case planExpression: PlanExpression[PlanType] =>
-            val newPlan = planExpression.plan.transformDownWithSubqueries(f)
+            val newPlan = planExpression.plan.transformDownWithSubqueries(cond, ruleId)(f)
             planExpression.withNewPlan(newPlan)
         }
       }
     }
 
-    transformDown(g)
+    transformDownWithPruning(cond, ruleId)(g)
   }
 
   /**
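
A toy model of the curried shape introduced above (assumed names, plain Scala; not the Spark implementation): a pruning condition gates whether a subtree is visited at all, and the rewrite itself is a partial function applied top-down wherever the condition holds.

```scala
object PruningTransformSketch extends App {
  sealed trait Node { def children: Seq[Node] }
  case class Leaf(tag: String) extends Node { val children = Nil }
  case class Branch(children: Seq[Node]) extends Node

  // cond decides whether a subtree can contain anything worth rewriting;
  // f is the rewrite, applied top-down only where cond holds.
  def transformDownWithPruning(cond: Node => Boolean)
      (f: PartialFunction[Node, Node])(node: Node): Node = {
    if (!cond(node)) node // prune: skip the whole subtree
    else f.applyOrElse(node, identity[Node]) match {
      case Branch(cs) => Branch(cs.map(transformDownWithPruning(cond)(f)))
      case leaf => leaf
    }
  }

  // Usage: rewrite only subtrees that actually contain a "now" leaf.
  def containsNow(n: Node): Boolean =
    n == Leaf("now") || n.children.exists(containsNow)

  val plan = Branch(Seq(Leaf("now"), Branch(Seq(Leaf("col")))))
  val result = transformDownWithPruning(containsNow) {
    case Leaf("now") => Leaf("literal")
  }(plan)
  println(result) // Branch(List(Leaf(literal), Branch(List(Leaf(col)))))
}
```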
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
index 9b04dcddfb2..c034906c09b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
@@ -19,10 +19,13 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import java.time.{LocalDateTime, ZoneId}
 
+import scala.collection.JavaConverters.mapAsScalaMap
+import scala.concurrent.duration._
+
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, Literal, LocalTimestamp}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, InSubquery, ListQuery, Literal, LocalTimestamp, Now}
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
@@ -41,11 +44,7 @@ class ComputeCurrentTimeSuite extends PlanTest {
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
     val max = (System.currentTimeMillis() + 1) * 1000
 
-    val lits = new scala.collection.mutable.ArrayBuffer[Long]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[Long]
-      e
-    }
+    val lits = literals[Long](plan)
     assert(lits.size == 2)
     assert(lits(0) >= min && lits(0) <= max)
     assert(lits(1) >= min && lits(1) <= max)
@@ -59,11 +58,7 @@ class ComputeCurrentTimeSuite extends PlanTest {
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
     val max = DateTimeUtils.currentDate(ZoneId.systemDefault())
 
-    val lits = new scala.collection.mutable.ArrayBuffer[Int]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[Int]
-      e
-    }
+    val lits = literals[Int](plan)
     assert(lits.size == 2)
     assert(lits(0) >= min && lits(0) <= max)
     assert(lits(1) >= min && lits(1) <= max)
@@ -73,13 +68,9 @@ class ComputeCurrentTimeSuite extends PlanTest {
   test("SPARK-33469: Add current_timezone function") {
     val in = Project(Seq(Alias(CurrentTimeZone(), "c")()), LocalRelation())
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
-    val lits = new scala.collection.mutable.ArrayBuffer[String]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[UTF8String].toString
-      e
-    }
+    val lits = literals[UTF8String](plan)
     assert(lits.size == 1)
-    assert(lits.head == SQLConf.get.sessionLocalTimeZone)
+    assert(lits.head == UTF8String.fromString(SQLConf.get.sessionLocalTimeZone))
   }
 
   test("analyzer should replace localtimestamp with literals") {
@@ -92,14 +83,66 @@ class ComputeCurrentTimeSuite extends PlanTest {
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
     val max = DateTimeUtils.localDateTimeToMicros(LocalDateTime.now(zoneId))
 
-    val lits = new scala.collection.mutable.ArrayBuffer[Long]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[Long]
-      e
-    }
+    val lits = literals[Long](plan)
     assert(lits.size == 2)
     assert(lits(0) >= min && lits(0) <= max)
     assert(lits(1) >= min && lits(1) <= max)
     assert(lits(0) == lits(1))
   }
+
+  test("analyzer should use equal timestamps across subqueries") {
+    val timestampInSubQuery = Project(Seq(Alias(LocalTimestamp(), "timestamp1")()), LocalRelation())
+    val listSubQuery = ListQuery(timestampInSubQuery)
+    val valueSearchedInSubQuery = Seq(Alias(LocalTimestamp(), "timestamp2")())
+    val inFilterWithSubQuery = InSubquery(valueSearchedInSubQuery, listSubQuery)
+    val input = Project(Nil, Filter(inFilterWithSubQuery, LocalRelation()))
+
+    val plan = Optimize.execute(input.analyze).asInstanceOf[Project]
+
+    val lits = literals[Long](plan)
+    assert(lits.size == 3) // transformDownWithSubqueries covers the inner timestamp twice
+    assert(lits.toSet.size == 1)
+  }
+
+  test("analyzer should use consistent timestamps for different timezones") {
+    val localTimestamps = mapAsScalaMap(ZoneId.SHORT_IDS)
+      .map { case (zoneId, _) => Alias(LocalTimestamp(Some(zoneId)), zoneId)() }.toSeq
+    val input = Project(localTimestamps, LocalRelation())
+
+    val plan = Optimize.execute(input).asInstanceOf[Project]
+
+    val lits = literals[Long](plan)
+    assert(lits.size === localTimestamps.size)
+    // there are timezones with a 30 or 45 minute offset
+    val offsetsFromQuarterHour = lits.map(_ % Duration(15, MINUTES).toMicros).toSet
+    assert(offsetsFromQuarterHour.size == 1)
+  }
+
+  test("analyzer should use consistent timestamps for different timestamp functions") {
+    val differentTimestamps = Seq(
+      Alias(CurrentTimestamp(), "currentTimestamp")(),
+      Alias(Now(), "now")(),
+      Alias(LocalTimestamp(Some("PLT")), "localTimestampWithTimezone")()
+    )
+    val input = Project(differentTimestamps, LocalRelation())
+
+    val plan = Optimize.execute(input).asInstanceOf[Project]
+
+    val lits = literals[Long](plan)
+    assert(lits.size === differentTimestamps.size)
+    // there are timezones with a 30 or 45 minute offset
+    val offsetsFromQuarterHour = lits.map(_ % Duration(15, MINUTES).toMicros).toSet
+    assert(offsetsFromQuarterHour.size == 1)
+  }
+
+  private def literals[T](plan: LogicalPlan): Seq[T] = {
+    val literals = new scala.collection.mutable.ArrayBuffer[T]
+    plan.transformWithSubqueries { case subQuery =>
+      subQuery.transformAllExpressions { case expression: Literal =>
+        literals += expression.value.asInstanceOf[T]
+        expression
+      }
+    }
+    literals.toSeq
+  }
 }
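
For reference, the quarter-hour trick the last two tests rely on, as a self-contained sketch (plain Scala, mirroring but not reusing the test code): every real-world UTC offset is a multiple of 15 minutes, so local timestamps derived from one shared instant are all congruent modulo 15 minutes of microseconds, even for zones with 30- or 45-minute offsets.

```scala
import java.time.{Instant, LocalDateTime, ZoneId, ZoneOffset}
import scala.collection.JavaConverters.mapAsScalaMap

object QuarterHourSketch extends App {
  val instant = Instant.now()
  val quarterHourMicros = 15L * 60 * 1000 * 1000

  // Render the same instant as "local" microseconds in every short-ID zone,
  // roughly the way LocalTimestamp is folded to a literal by the rule.
  val micros = mapAsScalaMap(ZoneId.SHORT_IDS).values.map { zone =>
    val ldt = LocalDateTime.ofInstant(instant, ZoneId.of(zone))
    ldt.toEpochSecond(ZoneOffset.UTC) * 1000000L + ldt.getNano / 1000L
  }

  // One shared instant => a single residue modulo 15 minutes across zones.
  assert(micros.map(_ % quarterHourMicros).toSet.size == 1)
  println("all zone-local timestamps agree modulo 15 minutes")
}
```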

