Posted to commits@spark.apache.org by ma...@apache.org on 2021/07/29 05:52:35 UTC

[spark] branch master updated: [SPARK-36323][SQL] Support ANSI interval literals for TimeWindow

This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new db18866  [SPARK-36323][SQL] Support ANSI interval literals for TimeWindow
db18866 is described below

commit db18866742a4641e7119f637024bc89a3f048634
Author: Kousuke Saruta <sa...@oss.nttdata.com>
AuthorDate: Thu Jul 29 08:51:51 2021 +0300

    [SPARK-36323][SQL] Support ANSI interval literals for TimeWindow
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to support ANSI interval literals for `TimeWindow`.
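    
    For illustration, a minimal sketch of the usage this enables, assuming a DataFrame `df` with a `timestamp` column (the names `df` and `timestamp` are placeholders, not taken from this commit):
    
    ```scala
    import org.apache.spark.sql.functions.{col, window}
    
    // `df` is an assumed DataFrame with a "timestamp" column.
    // Conventional interval string, supported before this change:
    val conventional = df.groupBy(window(col("timestamp"), "10 seconds")).count()
    
    // ANSI interval literal, accepted for TimeWindow after this change:
    val ansiLiteral = df.groupBy(window(col("timestamp"), "INTERVAL '10' SECOND")).count()
    ```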
    
    ### Why are the changes needed?
    
    Watermark already supports ANSI interval literals, so it is good to support them for `TimeWindow` as well.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New test.
    
    Closes #33551 from sarutak/window-interval.
    
    Authored-by: Kousuke Saruta <sa...@oss.nttdata.com>
    Signed-off-by: Max Gekk <ma...@gmail.com>
---
 .../sql/catalyst/expressions/TimeWindow.scala      |  5 +-
 .../spark/sql/catalyst/util/IntervalUtils.scala    | 23 +++++++-
 .../spark/sql/errors/QueryCompilationErrors.scala  |  4 +-
 .../sql/catalyst/expressions/TimeWindowSuite.scala | 62 +++++++++++++++++++---
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 20 +------
 5 files changed, 82 insertions(+), 32 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
index 8475c1f..2f08fd7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY
 import org.apache.spark.sql.catalyst.util.IntervalUtils
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 
 case class TimeWindow(
     timeColumn: Expression,
@@ -110,12 +109,12 @@ object TimeWindow {
    *         precision.
    */
   def getIntervalInMicroSeconds(interval: String): Long = {
-    val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval))
+    val cal = IntervalUtils.fromIntervalString(interval)
     if (cal.months != 0) {
       throw new IllegalArgumentException(
         s"Intervals greater than a month is not supported ($interval).")
     }
-    cal.days * MICROS_PER_DAY + cal.microseconds
+    Math.addExact(Math.multiplyExact(cal.days, MICROS_PER_DAY), cal.microseconds)
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index dc6c02e..62a2657 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -25,11 +25,14 @@ import java.util.concurrent.TimeUnit
 import scala.collection.mutable
 import scala.util.control.NonFatal
 
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToMicros
 import org.apache.spark.sql.catalyst.util.IntervalStringStyles.{ANSI_STYLE, HIVE_STYLE, IntervalStyle}
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.{DayTimeIntervalType => DT, Decimal, YearMonthIntervalType => YM}
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
@@ -433,6 +436,24 @@ object IntervalUtils {
     }
   }
 
+  /**
+   * Parse all kinds of interval literals including unit-to-unit form and unit list form
+   */
+  def fromIntervalString(input: String): CalendarInterval = try {
+    if (input.toLowerCase(Locale.ROOT).trim.startsWith("interval")) {
+      CatalystSqlParser.parseExpression(input) match {
+        case Literal(months: Int, _: YearMonthIntervalType) => new CalendarInterval(months, 0, 0)
+        case Literal(micros: Long, _: DayTimeIntervalType) => new CalendarInterval(0, 0, micros)
+        case Literal(cal: CalendarInterval, CalendarIntervalType) => cal
+      }
+    } else {
+      stringToInterval(UTF8String.fromString(input))
+    }
+  } catch {
+    case NonFatal(e) =>
+      throw QueryCompilationErrors.cannotParseIntervalError(input, e)
+  }
+
   private val dayTimePatternLegacy =
     "^([+|-])?((\\d+) )?((\\d+):)?(\\d+):(\\d+)(\\.(\\d+))?$".r
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 1421fa3..b62729a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -2261,8 +2261,8 @@ private[spark] object QueryCompilationErrors {
       s"""Cannot resolve column name "$colName" among (${fieldsStr})${extraMsg}""")
   }
 
-  def cannotParseTimeDelayError(delayThreshold: String, e: Throwable): Throwable = {
-    new AnalysisException(s"Unable to parse time delay '$delayThreshold'", cause = Some(e))
+  def cannotParseIntervalError(delayThreshold: String, e: Throwable): Throwable = {
+    new AnalysisException(s"Unable to parse '$delayThreshold'", cause = Some(e))
   }
 
   def invalidJoinTypeInJoinWithError(joinType: JoinType): Throwable = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
index a4860fa..faa8e6f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
@@ -17,10 +17,14 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.reflect.ClassTag
+
 import org.scalatest.PrivateMethodTester
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.util.DateTimeConstants._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampNTZType, TimestampType}
 
 class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with PrivateMethodTester {
@@ -31,16 +35,16 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
     }
   }
 
-  private def checkErrorMessage(msg: String, value: String): Unit = {
+  private def checkErrorMessage[E <: Exception : ClassTag](msg: String, value: String): Unit = {
     val validDuration = "10 second"
     val validTime = "5 second"
-    val e1 = intercept[IllegalArgumentException] {
+    val e1 = intercept[E] {
       TimeWindow(Literal(10L), value, validDuration, validTime).windowDuration
     }
-    val e2 = intercept[IllegalArgumentException] {
+    val e2 = intercept[E] {
       TimeWindow(Literal(10L), validDuration, value, validTime).slideDuration
     }
-    val e3 = intercept[IllegalArgumentException] {
+    val e3 = intercept[E] {
       TimeWindow(Literal(10L), validDuration, validDuration, value).startTime
     }
     Seq(e1, e2, e3).foreach { e =>
@@ -50,18 +54,18 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
 
   test("blank intervals throw exception") {
     for (blank <- Seq(null, " ", "\n", "\t")) {
-      checkErrorMessage(
+      checkErrorMessage[AnalysisException](
         "The window duration, slide duration and start time cannot be null or blank.", blank)
     }
   }
 
   test("invalid intervals throw exception") {
-    checkErrorMessage(
+    checkErrorMessage[AnalysisException](
       "did not correspond to a valid interval string.", "2 apples")
   }
 
   test("intervals greater than a month throws exception") {
-    checkErrorMessage(
+    checkErrorMessage[IllegalArgumentException](
       "Intervals greater than or equal to a month is not supported (1 month).", "1 month")
   }
 
@@ -111,7 +115,7 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
   }
 
   test("parse sql expression for duration in microseconds - invalid interval") {
-    intercept[IllegalArgumentException] {
+    intercept[AnalysisException] {
       TimeWindow.invokePrivate(parseExpression(Literal("2 apples")))
     }
   }
@@ -147,4 +151,46 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
     assert(timestampNTZWindow.dataType == StructType(
       Seq(StructField("start", TimestampNTZType), StructField("end", TimestampNTZType))))
   }
+
+  Seq("true", "false").foreach { legacyIntervalEnabled =>
+    test("SPARK-36323: Support ANSI interval literals for TimeWindow " +
+      s"(${SQLConf.LEGACY_INTERVAL_ENABLED.key}=$legacyIntervalEnabled)") {
+      withSQLConf(SQLConf.LEGACY_INTERVAL_ENABLED.key -> legacyIntervalEnabled) {
+        Seq(
+          // Conventional form and some variants
+          (Seq("3 days", "Interval 3 day", "inTerval '3' day"), 3 * MICROS_PER_DAY),
+          (Seq(" 5 hours", "INTERVAL 5 hour", "interval '5' hour"), 5 * MICROS_PER_HOUR),
+          (Seq("\t8 minutes", "interval 8 minute", "interval '8' minute"), 8 * MICROS_PER_MINUTE),
+          (Seq(
+            "10 seconds", "interval 10 second", "interval '10' second"), 10 * MICROS_PER_SECOND),
+          (Seq(
+            "1 day 2 hours 3 minutes 4 seconds",
+            " interval 1 day 2 hours 3 minutes 4 seconds",
+            "\tinterval '1' day '2' hours '3' minutes '4' seconds",
+            "interval '1 2:3:4' day to second"),
+            MICROS_PER_DAY + 2 * MICROS_PER_HOUR + 3 * MICROS_PER_MINUTE + 4 * MICROS_PER_SECOND)
+        ).foreach { case (intervalVariants, expectedMs) =>
+          intervalVariants.foreach { case interval =>
+            val timeWindow = TimeWindow(Literal(10L, TimestampType), interval, interval, interval)
+            val expected =
+              TimeWindow(Literal(10L, TimestampType), expectedMs, expectedMs, expectedMs)
+            assert(timeWindow === expected)
+          }
+        }
+
+        // year-month interval literals are not supported for TimeWindow.
+        Seq(
+          "1 years", "interval 1 year", "interval '1' year",
+          "1 months", "interval 1 month", "interval '1' month",
+          " 1 year 2 months",
+          "interval 1 year 2 month",
+          "interval '1' year '2' month",
+          "\tinterval '1-2' year to month").foreach { interval =>
+          intercept[IllegalArgumentException] {
+            TimeWindow(Literal(10L, TimestampType), interval, interval, interval)
+          }
+        }
+      }
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 88dbb9d..87b3bdd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql
 
 import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream}
-import java.util.Locale
 
 import scala.annotation.varargs
 import scala.collection.JavaConverters._
@@ -44,7 +43,7 @@ import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
-import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
+import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
@@ -65,7 +64,6 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.unsafe.array.ByteArrayMethods
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 import org.apache.spark.util.Utils
 
 private[sql] object Dataset {
@@ -741,21 +739,7 @@ class Dataset[T] private[sql](
   // We only accept an existing column name, not a derived column here as a watermark that is
   // defined on a derived column cannot referenced elsewhere in the plan.
   def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan {
-    val parsedDelay = try {
-      if (delayThreshold.toLowerCase(Locale.ROOT).trim.startsWith("interval")) {
-        CatalystSqlParser.parseExpression(delayThreshold) match {
-          case Literal(months: Int, _: YearMonthIntervalType) =>
-            new CalendarInterval(months, 0, 0)
-          case Literal(micros: Long, _: DayTimeIntervalType) =>
-            new CalendarInterval(0, 0, micros)
-        }
-      } else {
-        IntervalUtils.stringToInterval(UTF8String.fromString(delayThreshold))
-      }
-    } catch {
-      case NonFatal(e) =>
-        throw QueryCompilationErrors.cannotParseTimeDelayError(delayThreshold, e)
-    }
+    val parsedDelay = IntervalUtils.fromIntervalString(delayThreshold)
     require(!IntervalUtils.isNegative(parsedDelay),
       s"delay threshold ($delayThreshold) should not be negative.")
     EliminateEventTimeWatermark(
