You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2022/06/03 06:12:45 UTC
[spark] branch branch-3.3 updated: [SPARK-39259][SQL][3.3] Evaluate timestamps consistently in subqueries
This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
new 4a0f0ff6c22 [SPARK-39259][SQL][3.3] Evaluate timestamps consistently in subqueries
4a0f0ff6c22 is described below
commit 4a0f0ff6c22b85cb0fc1eef842da8dbe4c90543a
Author: Ole Sasse <ol...@databricks.com>
AuthorDate: Fri Jun 3 09:12:26 2022 +0300
[SPARK-39259][SQL][3.3] Evaluate timestamps consistently in subqueries
### What changes were proposed in this pull request?
Apply the optimizer rule ComputeCurrentTime consistently across subqueries.
This is a backport of https://github.com/apache/spark/pull/36654.
### Why are the changes needed?
At the moment, timestamp functions like now() can return different values within a single query if subqueries are involved.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New unit tests were added to ComputeCurrentTimeSuite.
Closes #36752 from olaky/SPARK-39259-spark_3_3.
Authored-by: Ole Sasse <ol...@databricks.com>
Signed-off-by: Max Gekk <ma...@gmail.com>
---
.../sql/catalyst/optimizer/finishAnalysis.scala | 41 +++++-----
.../spark/sql/catalyst/plans/QueryPlan.scala | 11 ++-
.../optimizer/ComputeCurrentTimeSuite.scala | 89 ++++++++++++++++------
3 files changed, 95 insertions(+), 46 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
index ef9c4b9af40..242c799dd22 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -17,14 +17,16 @@
package org.apache.spark.sql.catalyst.optimizer
-import scala.collection.mutable
+import java.time.{Instant, LocalDateTime}
import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreePattern._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ}
+import org.apache.spark.sql.catalyst.trees.TreePatternBits
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, instantToMicros, localDateTimeToMicros}
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -73,29 +75,30 @@ object RewriteNonCorrelatedExists extends Rule[LogicalPlan] {
*/
object ComputeCurrentTime extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
- val currentDates = mutable.Map.empty[String, Literal]
- val timeExpr = CurrentTimestamp()
- val timestamp = timeExpr.eval(EmptyRow).asInstanceOf[Long]
- val currentTime = Literal.create(timestamp, timeExpr.dataType)
+ val instant = Instant.now()
+ val currentTimestampMicros = instantToMicros(instant)
+ val currentTime = Literal.create(currentTimestampMicros, TimestampType)
val timezone = Literal.create(conf.sessionLocalTimeZone, StringType)
- val localTimestamps = mutable.Map.empty[String, Literal]
- plan.transformAllExpressionsWithPruning(_.containsPattern(CURRENT_LIKE)) {
- case currentDate @ CurrentDate(Some(timeZoneId)) =>
- currentDates.getOrElseUpdate(timeZoneId, {
- Literal.create(currentDate.eval().asInstanceOf[Int], DateType)
- })
- case CurrentTimestamp() | Now() => currentTime
- case CurrentTimeZone() => timezone
- case localTimestamp @ LocalTimestamp(Some(timeZoneId)) =>
- localTimestamps.getOrElseUpdate(timeZoneId, {
- Literal.create(localTimestamp.eval().asInstanceOf[Long], TimestampNTZType)
- })
+ def transformCondition(treePatternbits: TreePatternBits): Boolean = {
+ treePatternbits.containsPattern(CURRENT_LIKE)
+ }
+
+ plan.transformDownWithSubqueries(transformCondition) {
+ case subQuery =>
+ subQuery.transformAllExpressionsWithPruning(transformCondition) {
+ case cd: CurrentDate =>
+ Literal.create(DateTimeUtils.microsToDays(currentTimestampMicros, cd.zoneId), DateType)
+ case CurrentTimestamp() | Now() => currentTime
+ case CurrentTimeZone() => timezone
+ case localTimestamp: LocalTimestamp =>
+ val asDateTime = LocalDateTime.ofInstant(instant, localTimestamp.zoneId)
+ Literal.create(localDateTimeToMicros(asDateTime), TimestampNTZType)
+ }
}
}
}
-
/**
* Replaces the expression of CurrentDatabase with the current database name.
* Replaces the expression of CurrentCatalog with the current catalog name.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 0f8df5df376..d0283f4d367 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -454,7 +454,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
* to rewrite the whole plan, include its subqueries, in one go.
*/
def transformWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType =
- transformDownWithSubqueries(f)
+ transformDownWithSubqueries(AlwaysProcess.fn, UnknownRuleId)(f)
/**
* Returns a copy of this node where the given partial function has been recursively applied
@@ -479,7 +479,10 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
* first to this node, then this node's subqueries and finally this node's children.
* When the partial function does not apply to a given node, it is left unchanged.
*/
- def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = {
+ def transformDownWithSubqueries(
+ cond: TreePatternBits => Boolean = AlwaysProcess.fn, ruleId: RuleId = UnknownRuleId)
+ (f: PartialFunction[PlanType, PlanType])
+: PlanType = {
val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] {
override def isDefinedAt(x: PlanType): Boolean = true
@@ -487,13 +490,13 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
val transformed = f.applyOrElse[PlanType, PlanType](plan, identity)
transformed transformExpressionsDown {
case planExpression: PlanExpression[PlanType] =>
- val newPlan = planExpression.plan.transformDownWithSubqueries(f)
+ val newPlan = planExpression.plan.transformDownWithSubqueries(cond, ruleId)(f)
planExpression.withNewPlan(newPlan)
}
}
}
- transformDown(g)
+ transformDownWithPruning(cond, ruleId)(g)
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
index 9b04dcddfb2..c034906c09b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
@@ -19,10 +19,13 @@ package org.apache.spark.sql.catalyst.optimizer
import java.time.{LocalDateTime, ZoneId}
+import scala.collection.JavaConverters.mapAsScalaMap
+import scala.concurrent.duration._
+
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, Literal, LocalTimestamp}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, InSubquery, ListQuery, Literal, LocalTimestamp, Now}
import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.internal.SQLConf
@@ -41,11 +44,7 @@ class ComputeCurrentTimeSuite extends PlanTest {
val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
val max = (System.currentTimeMillis() + 1) * 1000
- val lits = new scala.collection.mutable.ArrayBuffer[Long]
- plan.transformAllExpressions { case e: Literal =>
- lits += e.value.asInstanceOf[Long]
- e
- }
+ val lits = literals[Long](plan)
assert(lits.size == 2)
assert(lits(0) >= min && lits(0) <= max)
assert(lits(1) >= min && lits(1) <= max)
@@ -59,11 +58,7 @@ class ComputeCurrentTimeSuite extends PlanTest {
val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
val max = DateTimeUtils.currentDate(ZoneId.systemDefault())
- val lits = new scala.collection.mutable.ArrayBuffer[Int]
- plan.transformAllExpressions { case e: Literal =>
- lits += e.value.asInstanceOf[Int]
- e
- }
+ val lits = literals[Int](plan)
assert(lits.size == 2)
assert(lits(0) >= min && lits(0) <= max)
assert(lits(1) >= min && lits(1) <= max)
@@ -73,13 +68,9 @@ class ComputeCurrentTimeSuite extends PlanTest {
test("SPARK-33469: Add current_timezone function") {
val in = Project(Seq(Alias(CurrentTimeZone(), "c")()), LocalRelation())
val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
- val lits = new scala.collection.mutable.ArrayBuffer[String]
- plan.transformAllExpressions { case e: Literal =>
- lits += e.value.asInstanceOf[UTF8String].toString
- e
- }
+ val lits = literals[UTF8String](plan)
assert(lits.size == 1)
- assert(lits.head == SQLConf.get.sessionLocalTimeZone)
+ assert(lits.head == UTF8String.fromString(SQLConf.get.sessionLocalTimeZone))
}
test("analyzer should replace localtimestamp with literals") {
@@ -92,14 +83,66 @@ class ComputeCurrentTimeSuite extends PlanTest {
val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
val max = DateTimeUtils.localDateTimeToMicros(LocalDateTime.now(zoneId))
- val lits = new scala.collection.mutable.ArrayBuffer[Long]
- plan.transformAllExpressions { case e: Literal =>
- lits += e.value.asInstanceOf[Long]
- e
- }
+ val lits = literals[Long](plan)
assert(lits.size == 2)
assert(lits(0) >= min && lits(0) <= max)
assert(lits(1) >= min && lits(1) <= max)
assert(lits(0) == lits(1))
}
+
+ test("analyzer should use equal timestamps across subqueries") {
+ val timestampInSubQuery = Project(Seq(Alias(LocalTimestamp(), "timestamp1")()), LocalRelation())
+ val listSubQuery = ListQuery(timestampInSubQuery)
+ val valueSearchedInSubQuery = Seq(Alias(LocalTimestamp(), "timestamp2")())
+ val inFilterWithSubQuery = InSubquery(valueSearchedInSubQuery, listSubQuery)
+ val input = Project(Nil, Filter(inFilterWithSubQuery, LocalRelation()))
+
+ val plan = Optimize.execute(input.analyze).asInstanceOf[Project]
+
+ val lits = literals[Long](plan)
+ assert(lits.size == 3) // transformDownWithSubqueries covers the inner timestamp twice
+ assert(lits.toSet.size == 1)
+ }
+
+ test("analyzer should use consistent timestamps for different timezones") {
+ val localTimestamps = mapAsScalaMap(ZoneId.SHORT_IDS)
+ .map { case (zoneId, _) => Alias(LocalTimestamp(Some(zoneId)), zoneId)() }.toSeq
+ val input = Project(localTimestamps, LocalRelation())
+
+ val plan = Optimize.execute(input).asInstanceOf[Project]
+
+ val lits = literals[Long](plan)
+ assert(lits.size === localTimestamps.size)
+ // there are timezones with a 30 or 45 minute offset
+ val offsetsFromQuarterHour = lits.map( _ % Duration(15, MINUTES).toMicros).toSet
+ assert(offsetsFromQuarterHour.size == 1)
+ }
+
+ test("analyzer should use consistent timestamps for different timestamp functions") {
+ val differentTimestamps = Seq(
+ Alias(CurrentTimestamp(), "currentTimestamp")(),
+ Alias(Now(), "now")(),
+ Alias(LocalTimestamp(Some("PLT")), "localTimestampWithTimezone")()
+ )
+ val input = Project(differentTimestamps, LocalRelation())
+
+ val plan = Optimize.execute(input).asInstanceOf[Project]
+
+ val lits = literals[Long](plan)
+ assert(lits.size === differentTimestamps.size)
+ // there are timezones with a 30 or 45 minute offset
+ val offsetsFromQuarterHour = lits.map( _ % Duration(15, MINUTES).toMicros).toSet
+ assert(offsetsFromQuarterHour.size == 1)
+ }
+
+ private def literals[T](plan: LogicalPlan): Seq[T] = {
+ val literals = new scala.collection.mutable.ArrayBuffer[T]
+ plan.transformWithSubqueries { case subQuery =>
+ subQuery.transformAllExpressions { case expression: Literal =>
+ literals += expression.value.asInstanceOf[T]
+ expression
+ }
+ }
+ literals.asInstanceOf[Seq[T]]
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org