Posted to commits@spark.apache.org by ge...@apache.org on 2021/07/19 11:25:17 UTC

[spark] branch branch-3.2 updated: [SPARK-36091][SQL] Support TimestampNTZ type in expression TimeWindow

This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new ab4c160  [SPARK-36091][SQL] Support TimestampNTZ type in expression TimeWindow
ab4c160 is described below

commit ab4c160880faf6bc5c85049b41cea58500a654f8
Author: gengjiaan <ge...@360.cn>
AuthorDate: Mon Jul 19 19:23:39 2021 +0800

    [SPARK-36091][SQL] Support TimestampNTZ type in expression TimeWindow
    
    ### What changes were proposed in this pull request?
    The current implementation of `TimeWindow` only supports `TimestampType`. Spark has added the new `TimestampNTZType`, so the `TimeWindow` expression should support it as well.
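
    For illustration, a minimal expression-level sketch (mirroring the new `TimeWindowSuite` test in this patch; not itself part of the change):

    ```scala
    import org.apache.spark.sql.catalyst.expressions.{Literal, TimeWindow}
    import org.apache.spark.sql.types.TimestampNTZType

    // With inputTypes widened to AnyTimestampType, a TIMESTAMP_NTZ input
    // now type-checks, and the window struct mirrors the input's type:
    // w.dataType == struct<start: timestamp_ntz, end: timestamp_ntz>
    val w = TimeWindow(Literal(10L, TimestampNTZType), "10 seconds", "10 seconds", "0 seconds")
    ```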
    
    ### Why are the changes needed?
    `TimestampNTZType` behaves like `TimestampType`, so time windows should work over both types.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `TimeWindow` will accept `TimestampNTZType`.
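
    A minimal usage sketch (assumes a Spark session with `spark.implicits._` in scope; values chosen for illustration, as in the new tests below):

    ```scala
    import java.time.LocalDateTime
    import org.apache.spark.sql.functions._

    // A LocalDateTime column maps to TIMESTAMP_NTZ; window() over it now resolves.
    val ntz = Seq((LocalDateTime.parse("2016-03-27T19:39:30"), 1)).toDF("time", "value")
    val counts = ntz.groupBy(window($"time", "10 seconds")).count()
    // The window struct keeps the input type:
    // counts.schema("window").dataType == struct<start: timestamp_ntz, end: timestamp_ntz>
    ```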
    
    ### How was this patch tested?
    New tests.
    
    Closes #33341 from beliefer/SPARK-36091.
    
    Lead-authored-by: gengjiaan <ge...@360.cn>
    Co-authored-by: Jiaan Geng <be...@163.com>
    Signed-off-by: Gengliang Wang <ge...@apache.org>
    (cherry picked from commit 7aa01798c5772426147d3f03dd121b57a550a050)
    Signed-off-by: Gengliang Wang <ge...@apache.org>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  13 +-
 .../sql/catalyst/expressions/TimeWindow.scala      |   6 +-
 .../sql/catalyst/expressions/TimeWindowSuite.scala |  16 +-
 .../spark/sql/DataFrameTimeWindowingSuite.scala    | 454 ++++++++++++++-------
 4 files changed, 321 insertions(+), 168 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e8ab874..ed7ad7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3874,9 +3874,9 @@ object TimeWindowing extends Rule[LogicalPlan] {
           case _ => Metadata.empty
         }
 
-        def getWindow(i: Int, overlappingWindows: Int): Expression = {
+        def getWindow(i: Int, overlappingWindows: Int, dataType: DataType): Expression = {
           val division = (PreciseTimestampConversion(
-            window.timeColumn, TimestampType, LongType) - window.startTime) / window.slideDuration
+            window.timeColumn, dataType, LongType) - window.startTime) / window.slideDuration
           val ceil = Ceil(division)
           // if the division is equal to the ceiling, our record is the start of a window
           val windowId = CaseWhen(Seq((ceil === division, ceil + 1)), Some(ceil))
@@ -3886,9 +3886,9 @@ object TimeWindowing extends Rule[LogicalPlan] {
 
           CreateNamedStruct(
             Literal(WINDOW_START) ::
-              PreciseTimestampConversion(windowStart, LongType, TimestampType) ::
+              PreciseTimestampConversion(windowStart, LongType, dataType) ::
               Literal(WINDOW_END) ::
-              PreciseTimestampConversion(windowEnd, LongType, TimestampType) ::
+              PreciseTimestampConversion(windowEnd, LongType, dataType) ::
               Nil)
         }
 
@@ -3896,7 +3896,7 @@ object TimeWindowing extends Rule[LogicalPlan] {
           WINDOW_COL_NAME, window.dataType, metadata = metadata)()
 
         if (window.windowDuration == window.slideDuration) {
-          val windowStruct = Alias(getWindow(0, 1), WINDOW_COL_NAME)(
+          val windowStruct = Alias(getWindow(0, 1, window.timeColumn.dataType), WINDOW_COL_NAME)(
             exprId = windowAttr.exprId, explicitMetadata = Some(metadata))
 
           val replacedPlan = p transformExpressions {
@@ -3913,7 +3913,8 @@ object TimeWindowing extends Rule[LogicalPlan] {
           val overlappingWindows =
             math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
           val windows =
-            Seq.tabulate(overlappingWindows)(i => getWindow(i, overlappingWindows))
+            Seq.tabulate(overlappingWindows)(i =>
+              getWindow(i, overlappingWindows, window.timeColumn.dataType))
 
           val projections = windows.map(_ +: child.output)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
index e79e8d7..8475c1f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -60,10 +60,10 @@ case class TimeWindow(
   }
 
   override def child: Expression = timeColumn
-  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimestampType)
   override def dataType: DataType = new StructType()
-    .add(StructField("start", TimestampType))
-    .add(StructField("end", TimestampType))
+    .add(StructField("start", child.dataType))
+    .add(StructField("end", child.dataType))
   override def prettyName: String = "window"
   final override val nodePatterns: Seq[TreePattern] = Seq(TIME_WINDOW)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
index e9d2178..a4860fa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
@@ -21,7 +21,7 @@ import org.scalatest.PrivateMethodTester
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.types.LongType
+import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampNTZType, TimestampType}
 
 class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with PrivateMethodTester {
 
@@ -133,4 +133,18 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
       assert(applyValue == constructed)
     }
   }
+
+  test("SPARK-36091: Support TimestampNTZ type in expression TimeWindow") {
+    val timestampWindow =
+      TimeWindow(Literal(10L, TimestampType), "10 seconds", "10 seconds", "0 seconds")
+    assert(timestampWindow.child.dataType == TimestampType)
+    assert(timestampWindow.dataType == StructType(
+      Seq(StructField("start", TimestampType), StructField("end", TimestampType))))
+
+    val timestampNTZWindow =
+      TimeWindow(Literal(10L, TimestampNTZType), "10 seconds", "10 seconds", "0 seconds")
+    assert(timestampNTZWindow.child.dataType == TimestampNTZType)
+    assert(timestampNTZWindow.dataType == StructType(
+      Seq(StructField("start", TimestampNTZType), StructField("end", TimestampNTZType))))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
index 2ef43dc..c385d9f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
@@ -17,184 +17,249 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.catalyst.plans.logical.Expand
+import java.time.LocalDateTime
+
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types._
 
 class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession {
 
   import testImplicits._
 
   test("simple tumbling window with record at window start") {
-    val df = Seq(
-      ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id")
-
-    checkAnswer(
-      df.groupBy(window($"time", "10 seconds"))
-        .agg(count("*").as("counts"))
-        .orderBy($"window.start".asc)
-        .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"),
-      Seq(
-        Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1)
+    val df1 = Seq(("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id")
+    val df2 = Seq((LocalDateTime.parse("2016-03-27T19:39:30"), 1, "a")).toDF("time", "value", "id")
+
+    Seq(df1, df2).foreach { df =>
+      checkAnswer(
+        df.groupBy(window($"time", "10 seconds"))
+          .agg(count("*").as("counts"))
+          .orderBy($"window.start".asc)
+          .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"),
+        Seq(
+          Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1)
+        )
       )
-    )
+    }
   }
 
   test("SPARK-21590: tumbling window using negative start time") {
-    val df = Seq(
+    val df1 = Seq(
       ("2016-03-27 19:39:30", 1, "a"),
       ("2016-03-27 19:39:25", 2, "a")).toDF("time", "value", "id")
+    val df2 = Seq((LocalDateTime.parse("2016-03-27T19:39:30"), 1, "a"),
+      (LocalDateTime.parse("2016-03-27T19:39:25"), 2, "a")).toDF("time", "value", "id")
 
-    checkAnswer(
-      df.groupBy(window($"time", "10 seconds", "10 seconds", "-5 seconds"))
-        .agg(count("*").as("counts"))
-        .orderBy($"window.start".asc)
-        .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"),
-      Seq(
-        Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 2)
+    Seq(df1, df2).foreach { df =>
+      checkAnswer(
+        df.groupBy(window($"time", "10 seconds", "10 seconds", "-5 seconds"))
+          .agg(count("*").as("counts"))
+          .orderBy($"window.start".asc)
+          .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"),
+        Seq(
+          Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 2)
+        )
       )
-    )
+    }
   }
 
   test("tumbling window groupBy statement") {
-    val df = Seq(
+    val df1 = Seq(
       ("2016-03-27 19:39:34", 1, "a"),
       ("2016-03-27 19:39:56", 2, "a"),
       ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+    val df2 = Seq(
+      (LocalDateTime.parse("2016-03-27T19:39:34"), 1, "a"),
+      (LocalDateTime.parse("2016-03-27T19:39:56"), 2, "a"),
+      (LocalDateTime.parse("2016-03-27T19:39:27"), 4, "b")).toDF("time", "value", "id")
 
-    checkAnswer(
-      df.groupBy(window($"time", "10 seconds"))
-        .agg(count("*").as("counts"))
-        .orderBy($"window.start".asc)
-        .select("counts"),
-      Seq(Row(1), Row(1), Row(1))
-    )
+    Seq(df1, df2).foreach { df =>
+      checkAnswer(
+        df.groupBy(window($"time", "10 seconds"))
+          .agg(count("*").as("counts"))
+          .orderBy($"window.start".asc)
+          .select("counts"),
+        Seq(Row(1), Row(1), Row(1))
+      )
+    }
   }
 
   test("tumbling window groupBy statement with startTime") {
-    val df = Seq(
+    val df1 = Seq(
       ("2016-03-27 19:39:34", 1, "a"),
       ("2016-03-27 19:39:56", 2, "a"),
       ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+    val df2 = Seq(
+      (LocalDateTime.parse("2016-03-27T19:39:34"), 1, "a"),
+      (LocalDateTime.parse("2016-03-27T19:39:56"), 2, "a"),
+      (LocalDateTime.parse("2016-03-27T19:39:27"), 4, "b")).toDF("time", "value", "id")
 
-    checkAnswer(
-      df.groupBy(window($"time", "10 seconds", "10 seconds", "5 seconds"), $"id")
-        .agg(count("*").as("counts"))
-        .orderBy($"window.start".asc)
-        .select("counts"),
-      Seq(Row(1), Row(1), Row(1)))
+    Seq(df1, df2).foreach { df =>
+      checkAnswer(
+        df.groupBy(window($"time", "10 seconds", "10 seconds", "5 seconds"), $"id")
+          .agg(count("*").as("counts"))
+          .orderBy($"window.start".asc)
+          .select("counts"),
+        Seq(Row(1), Row(1), Row(1))
+      )
+    }
   }
 
   test("SPARK-21590: tumbling window groupBy statement with negative startTime") {
-    val df = Seq(
+    val df1 = Seq(
       ("2016-03-27 19:39:34", 1, "a"),
       ("2016-03-27 19:39:56", 2, "a"),
       ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+    val df2 = Seq(
+      (LocalDateTime.parse("2016-03-27T19:39:34"), 1, "a"),
+      (LocalDateTime.parse("2016-03-27T19:39:56"), 2, "a"),
+      (LocalDateTime.parse("2016-03-27T19:39:27"), 4, "b")).toDF("time", "value", "id")
 
-    checkAnswer(
-      df.groupBy(window($"time", "10 seconds", "10 seconds", "-5 seconds"), $"id")
-        .agg(count("*").as("counts"))
-        .orderBy($"window.start".asc)
-        .select("counts"),
-      Seq(Row(1), Row(1), Row(1)))
+    Seq(df1, df2).foreach { df =>
+      checkAnswer(
+        df.groupBy(window($"time", "10 seconds", "10 seconds", "-5 seconds"), $"id")
+          .agg(count("*").as("counts"))
+          .orderBy($"window.start".asc)
+          .select("counts"),
+        Seq(Row(1), Row(1), Row(1))
+      )
+    }
   }
 
   test("tumbling window with multi-column projection") {
-    val df = Seq(
-        ("2016-03-27 19:39:34", 1, "a"),
-        ("2016-03-27 19:39:56", 2, "a"),
-        ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+    val df1 = Seq(
+      ("2016-03-27 19:39:34", 1, "a"),
+      ("2016-03-27 19:39:56", 2, "a"),
+      ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+      .select(window($"time", "10 seconds"), $"value")
+      .orderBy($"window.start".asc)
+      .select($"window.start".cast("string"), $"window.end".cast("string"), $"value")
+    val df2 = Seq(
+      (LocalDateTime.parse("2016-03-27T19:39:34"), 1, "a"),
+      (LocalDateTime.parse("2016-03-27T19:39:56"), 2, "a"),
+      (LocalDateTime.parse("2016-03-27T19:39:27"), 4, "b")).toDF("time", "value", "id")
       .select(window($"time", "10 seconds"), $"value")
       .orderBy($"window.start".asc)
       .select($"window.start".cast("string"), $"window.end".cast("string"), $"value")
 
-    val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand])
-    assert(expands.isEmpty, "Tumbling windows shouldn't require expand")
+    Seq(df1, df2).foreach { df =>
+      val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand])
+      assert(expands.isEmpty, "Tumbling windows shouldn't require expand")
 
-    checkAnswer(
-      df,
-      Seq(
-        Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4),
-        Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1),
-        Row("2016-03-27 19:39:50", "2016-03-27 19:40:00", 2)
+      checkAnswer(
+        df,
+        Seq(
+          Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4),
+          Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1),
+          Row("2016-03-27 19:39:50", "2016-03-27 19:40:00", 2)
+        )
       )
-    )
+    }
   }
 
   test("sliding window grouping") {
-    val df = Seq(
+    val df1 = Seq(
       ("2016-03-27 19:39:34", 1, "a"),
       ("2016-03-27 19:39:56", 2, "a"),
       ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+    val df2 = Seq(
+      (LocalDateTime.parse("2016-03-27T19:39:34"), 1, "a"),
+      (LocalDateTime.parse("2016-03-27T19:39:56"), 2, "a"),
+      (LocalDateTime.parse("2016-03-27T19:39:27"), 4, "b")).toDF("time", "value", "id")
 
-    checkAnswer(
-      df.groupBy(window($"time", "10 seconds", "3 seconds", "0 second"))
-        .agg(count("*").as("counts"))
-        .orderBy($"window.start".asc)
-        .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"),
-      // 2016-03-27 19:39:27 UTC -> 4 bins
-      // 2016-03-27 19:39:34 UTC -> 3 bins
-      // 2016-03-27 19:39:56 UTC -> 3 bins
-      Seq(
-        Row("2016-03-27 19:39:18", "2016-03-27 19:39:28", 1),
-        Row("2016-03-27 19:39:21", "2016-03-27 19:39:31", 1),
-        Row("2016-03-27 19:39:24", "2016-03-27 19:39:34", 1),
-        Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 2),
-        Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1),
-        Row("2016-03-27 19:39:33", "2016-03-27 19:39:43", 1),
-        Row("2016-03-27 19:39:48", "2016-03-27 19:39:58", 1),
-        Row("2016-03-27 19:39:51", "2016-03-27 19:40:01", 1),
-        Row("2016-03-27 19:39:54", "2016-03-27 19:40:04", 1))
-    )
+    Seq(df1, df2).foreach { df =>
+      checkAnswer(
+        df.groupBy(window($"time", "10 seconds", "3 seconds", "0 second"))
+          .agg(count("*").as("counts"))
+          .orderBy($"window.start".asc)
+          .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"),
+        // 2016-03-27 19:39:27 UTC -> 4 bins
+        // 2016-03-27 19:39:34 UTC -> 3 bins
+        // 2016-03-27 19:39:56 UTC -> 3 bins
+        Seq(
+          Row("2016-03-27 19:39:18", "2016-03-27 19:39:28", 1),
+          Row("2016-03-27 19:39:21", "2016-03-27 19:39:31", 1),
+          Row("2016-03-27 19:39:24", "2016-03-27 19:39:34", 1),
+          Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 2),
+          Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1),
+          Row("2016-03-27 19:39:33", "2016-03-27 19:39:43", 1),
+          Row("2016-03-27 19:39:48", "2016-03-27 19:39:58", 1),
+          Row("2016-03-27 19:39:51", "2016-03-27 19:40:01", 1),
+          Row("2016-03-27 19:39:54", "2016-03-27 19:40:04", 1))
+      )
+    }
   }
 
   test("sliding window projection") {
-    val df = Seq(
-        ("2016-03-27 19:39:34", 1, "a"),
-        ("2016-03-27 19:39:56", 2, "a"),
-        ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+    val df1 = Seq(
+      ("2016-03-27 19:39:34", 1, "a"),
+      ("2016-03-27 19:39:56", 2, "a"),
+      ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+      .select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value")
+      .orderBy($"window.start".asc, $"value".desc).select("value")
+    val df2 = Seq(
+      (LocalDateTime.parse("2016-03-27T19:39:34"), 1, "a"),
+      (LocalDateTime.parse("2016-03-27T19:39:56"), 2, "a"),
+      (LocalDateTime.parse("2016-03-27T19:39:27"), 4, "b")).toDF("time", "value", "id")
       .select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value")
       .orderBy($"window.start".asc, $"value".desc).select("value")
 
-    val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand])
-    assert(expands.nonEmpty, "Sliding windows require expand")
+    Seq(df1, df2).foreach { df =>
+      val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand])
+      assert(expands.nonEmpty, "Sliding windows require expand")
 
-    checkAnswer(
-      df,
-      // 2016-03-27 19:39:27 UTC -> 4 bins
-      // 2016-03-27 19:39:34 UTC -> 3 bins
-      // 2016-03-27 19:39:56 UTC -> 3 bins
-      Seq(Row(4), Row(4), Row(4), Row(4), Row(1), Row(1), Row(1), Row(2), Row(2), Row(2))
-    )
+      checkAnswer(
+        df,
+        // 2016-03-27 19:39:27 UTC -> 4 bins
+        // 2016-03-27 19:39:34 UTC -> 3 bins
+        // 2016-03-27 19:39:56 UTC -> 3 bins
+        Seq(Row(4), Row(4), Row(4), Row(4), Row(1), Row(1), Row(1), Row(2), Row(2), Row(2))
+      )
+    }
   }
 
   test("windowing combined with explode expression") {
-    val df = Seq(
+    val df1 = Seq(
       ("2016-03-27 19:39:34", 1, Seq("a", "b")),
       ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids")
+    val df2 = Seq(
+      (LocalDateTime.parse("2016-03-27T19:39:34"), 1, Seq("a", "b")),
+      (LocalDateTime.parse("2016-03-27T19:39:56"), 2, Seq("a", "c", "d"))).toDF(
+"time", "value", "ids")
 
-    checkAnswer(
-      df.select(window($"time", "10 seconds"), $"value", explode($"ids"))
-        .orderBy($"window.start".asc).select("value"),
-      // first window exploded to two rows for "a", and "b", second window exploded to 3 rows
-      Seq(Row(1), Row(1), Row(2), Row(2), Row(2))
-    )
+    Seq(df1, df2).foreach { df =>
+      checkAnswer(
+        df.select(window($"time", "10 seconds"), $"value", explode($"ids"))
+          .orderBy($"window.start".asc).select("value"),
+        // first window exploded to two rows for "a", and "b", second window exploded to 3 rows
+        Seq(Row(1), Row(1), Row(2), Row(2), Row(2))
+      )
+    }
   }
 
   test("null timestamps") {
-    val df = Seq(
+    val df1 = Seq(
       ("2016-03-27 09:00:05", 1),
       ("2016-03-27 09:00:32", 2),
       (null, 3),
       (null, 4)).toDF("time", "value")
+    val df2 = Seq(
+      (LocalDateTime.parse("2016-03-27T09:00:05"), 1),
+      (LocalDateTime.parse("2016-03-27T09:00:32"), 2),
+      (null, 3),
+      (null, 4)).toDF("time", "value")
 
-    checkDataset(
-      df.select(window($"time", "10 seconds"), $"value")
-        .orderBy($"window.start".asc)
-        .select("value")
-        .as[Int],
-      1, 2) // null columns are dropped
+    Seq(df1, df2).foreach { df =>
+      checkDataset(
+        df.select(window($"time", "10 seconds"), $"value")
+          .orderBy($"window.start".asc)
+          .select("value")
+          .as[Int],
+        1, 2) // null columns are dropped
+    }
   }
 
   test("time window joins") {
@@ -208,89 +273,135 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession {
       ("2016-03-27 09:00:02", 3),
       ("2016-03-27 09:00:35", 6)).toDF("time", "othervalue")
 
-    checkAnswer(
-      df.select(window($"time", "10 seconds"), $"value").join(
-        df2.select(window($"time", "10 seconds"), $"othervalue"), Seq("window"))
-        .groupBy("window")
-        .agg((sum("value") + sum("othervalue")).as("total"))
-        .orderBy($"window.start".asc).select("total"),
-      Seq(Row(4), Row(8)))
+    val df3 = Seq(
+      (LocalDateTime.parse("2016-03-27T09:00:05"), 1),
+      (LocalDateTime.parse("2016-03-27T09:00:32"), 2),
+      (null, 3),
+      (null, 4)).toDF("time", "value")
+
+    val df4 = Seq(
+      (LocalDateTime.parse("2016-03-27T09:00:02"), 3),
+      (LocalDateTime.parse("2016-03-27T09:00:35"), 6)).toDF("time", "othervalue")
+
+    Seq((df, df2), (df3, df4)).foreach { tuple =>
+      checkAnswer(
+        tuple._1.select(window($"time", "10 seconds"), $"value").join(
+          tuple._2.select(window($"time", "10 seconds"), $"othervalue"), Seq("window"))
+          .groupBy("window")
+          .agg((sum("value") + sum("othervalue")).as("total"))
+          .orderBy($"window.start".asc).select("total"),
+        Seq(Row(4), Row(8))
+      )
+    }
   }
 
   test("negative timestamps") {
-    val df4 = Seq(
+    val df1 = Seq(
       ("1970-01-01 00:00:02", 1),
       ("1970-01-01 00:00:12", 2)).toDF("time", "value")
-    checkAnswer(
-      df4.select(window($"time", "10 seconds", "10 seconds", "5 seconds"), $"value")
-        .orderBy($"window.start".asc)
-        .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"),
-      Seq(
-        Row("1969-12-31 23:59:55", "1970-01-01 00:00:05", 1),
-        Row("1970-01-01 00:00:05", "1970-01-01 00:00:15", 2))
-    )
+    val df2 = Seq(
+      (LocalDateTime.parse("1970-01-01T00:00:02"), 1),
+      (LocalDateTime.parse("1970-01-01T00:00:12"), 2)).toDF("time", "value")
+
+    Seq(df1, df2).foreach { df =>
+      checkAnswer(
+        df.select(window($"time", "10 seconds", "10 seconds", "5 seconds"), $"value")
+          .orderBy($"window.start".asc)
+          .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"),
+        Seq(
+          Row("1969-12-31 23:59:55", "1970-01-01 00:00:05", 1),
+          Row("1970-01-01 00:00:05", "1970-01-01 00:00:15", 2))
+      )
+    }
   }
 
   test("multiple time windows in a single operator throws nice exception") {
-    val df = Seq(
+    val df1 = Seq(
       ("2016-03-27 09:00:02", 3),
       ("2016-03-27 09:00:35", 6)).toDF("time", "value")
-    val e = intercept[AnalysisException] {
-      df.select(window($"time", "10 second"), window($"time", "15 second")).collect()
+    val df2 = Seq(
+      (LocalDateTime.parse("2016-03-27T09:00:02"), 3),
+      (LocalDateTime.parse("2016-03-27T09:00:35"), 6)).toDF("time", "value")
+
+    Seq(df1, df2).foreach { df =>
+      val e = intercept[AnalysisException] {
+        df.select(window($"time", "10 second"), window($"time", "15 second")).collect()
+      }
+      assert(e.getMessage.contains(
+        "Multiple time/session window expressions would result in a cartesian product"))
     }
-    assert(e.getMessage.contains(
-      "Multiple time/session window expressions would result in a cartesian product"))
   }
 
   test("aliased windows") {
-    val df = Seq(
+    val df1 = Seq(
       ("2016-03-27 19:39:34", 1, Seq("a", "b")),
       ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids")
+    val df2 = Seq(
+      (LocalDateTime.parse("2016-03-27T19:39:34"), 1, Seq("a", "b")),
+      (LocalDateTime.parse("2016-03-27T19:39:56"), 2, Seq("a", "c", "d"))).toDF(
+      "time", "value", "ids")
 
-    checkAnswer(
-      df.select(window($"time", "10 seconds").as("time_window"), $"value")
-        .orderBy($"time_window.start".asc)
-        .select("value"),
-      Seq(Row(1), Row(2))
-    )
+    Seq(df1, df2).foreach { df =>
+      checkAnswer(
+        df.select(window($"time", "10 seconds").as("time_window"), $"value")
+          .orderBy($"time_window.start".asc)
+          .select("value"),
+        Seq(Row(1), Row(2))
+      )
+    }
   }
 
   test("millisecond precision sliding windows") {
-    val df = Seq(
+    val df1 = Seq(
       ("2016-03-27 09:00:00.41", 3),
       ("2016-03-27 09:00:00.62", 6),
       ("2016-03-27 09:00:00.715", 8)).toDF("time", "value")
-    checkAnswer(
-      df.groupBy(window($"time", "200 milliseconds", "40 milliseconds", "0 milliseconds"))
-        .agg(count("*").as("counts"))
-        .orderBy($"window.start".asc)
-        .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"counts"),
-      Seq(
-        Row("2016-03-27 09:00:00.24", "2016-03-27 09:00:00.44", 1),
-        Row("2016-03-27 09:00:00.28", "2016-03-27 09:00:00.48", 1),
-        Row("2016-03-27 09:00:00.32", "2016-03-27 09:00:00.52", 1),
-        Row("2016-03-27 09:00:00.36", "2016-03-27 09:00:00.56", 1),
-        Row("2016-03-27 09:00:00.4", "2016-03-27 09:00:00.6", 1),
-        Row("2016-03-27 09:00:00.44", "2016-03-27 09:00:00.64", 1),
-        Row("2016-03-27 09:00:00.48", "2016-03-27 09:00:00.68", 1),
-        Row("2016-03-27 09:00:00.52", "2016-03-27 09:00:00.72", 2),
-        Row("2016-03-27 09:00:00.56", "2016-03-27 09:00:00.76", 2),
-        Row("2016-03-27 09:00:00.6", "2016-03-27 09:00:00.8", 2),
-        Row("2016-03-27 09:00:00.64", "2016-03-27 09:00:00.84", 1),
-        Row("2016-03-27 09:00:00.68", "2016-03-27 09:00:00.88", 1))
-    )
+    val df2 = Seq(
+      (LocalDateTime.parse("2016-03-27T09:00:00.41"), 3),
+      (LocalDateTime.parse("2016-03-27T09:00:00.62"), 6),
+      (LocalDateTime.parse("2016-03-27T09:00:00.715"), 8)).toDF("time", "value")
+
+    Seq(df1, df2).foreach { df =>
+      checkAnswer(
+        df.groupBy(window($"time", "200 milliseconds", "40 milliseconds", "0 milliseconds"))
+          .agg(count("*").as("counts"))
+          .orderBy($"window.start".asc)
+          .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"counts"),
+        Seq(
+          Row("2016-03-27 09:00:00.24", "2016-03-27 09:00:00.44", 1),
+          Row("2016-03-27 09:00:00.28", "2016-03-27 09:00:00.48", 1),
+          Row("2016-03-27 09:00:00.32", "2016-03-27 09:00:00.52", 1),
+          Row("2016-03-27 09:00:00.36", "2016-03-27 09:00:00.56", 1),
+          Row("2016-03-27 09:00:00.4", "2016-03-27 09:00:00.6", 1),
+          Row("2016-03-27 09:00:00.44", "2016-03-27 09:00:00.64", 1),
+          Row("2016-03-27 09:00:00.48", "2016-03-27 09:00:00.68", 1),
+          Row("2016-03-27 09:00:00.52", "2016-03-27 09:00:00.72", 2),
+          Row("2016-03-27 09:00:00.56", "2016-03-27 09:00:00.76", 2),
+          Row("2016-03-27 09:00:00.6", "2016-03-27 09:00:00.8", 2),
+          Row("2016-03-27 09:00:00.64", "2016-03-27 09:00:00.84", 1),
+          Row("2016-03-27 09:00:00.68", "2016-03-27 09:00:00.88", 1))
+      )
+    }
   }
 
   private def withTempTable(f: String => Unit): Unit = {
     val tableName = "temp"
-    Seq(
+    val df1 = Seq(
       ("2016-03-27 19:39:34", 1),
       ("2016-03-27 19:39:56", 2),
-      ("2016-03-27 19:39:27", 4)).toDF("time", "value").createOrReplaceTempView(tableName)
-    try {
-      f(tableName)
-    } finally {
-      spark.catalog.dropTempView(tableName)
+      ("2016-03-27 19:39:27", 4)).toDF("time", "value")
+    val df2 = Seq(
+      (LocalDateTime.parse("2016-03-27T19:39:34"), 1),
+      (LocalDateTime.parse("2016-03-27T19:39:56"), 2),
+      (LocalDateTime.parse("2016-03-27T19:39:27"), 4)).toDF("time", "value")
+
+    Seq(df1, df2).foreach { df =>
+      df.createOrReplaceTempView(tableName)
+      try {
+        f(tableName)
+      } finally {
+        spark.catalog.dropTempView(tableName)
+      }
     }
   }
 
@@ -352,4 +463,31 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession {
       )
     }
   }
+
+  test("SPARK-36091: Support TimestampNTZ type in expression TimeWindow") {
+    val df1 = Seq(
+      ("2016-03-27 19:39:30", 1, "a"),
+      ("2016-03-27 19:39:25", 2, "a")).toDF("time", "value", "id")
+    val df2 = Seq((LocalDateTime.parse("2016-03-27T19:39:30"), 1, "a"),
+      (LocalDateTime.parse("2016-03-27T19:39:25"), 2, "a")).toDF("time", "value", "id")
+    val type1 = StructType(
+      Seq(StructField("start", TimestampType), StructField("end", TimestampType)))
+    val type2 = StructType(
+      Seq(StructField("start", TimestampNTZType), StructField("end", TimestampNTZType)))
+
+    Seq((df1, type1), (df2, type2)).foreach { tuple =>
+      val logicalPlan =
+        tuple._1.groupBy(window($"time", "10 seconds", "10 seconds", "-5 seconds"))
+          .agg(count("*").as("counts"))
+          .orderBy($"window.start".asc)
+          .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts")
+      val aggregate = logicalPlan.queryExecution.analyzed.children(0).children(0)
+      assert(aggregate.isInstanceOf[Aggregate])
+      val timeWindow = aggregate.asInstanceOf[Aggregate].groupingExpressions(0)
+      assert(timeWindow.isInstanceOf[AttributeReference])
+      val attributeReference = timeWindow.asInstanceOf[AttributeReference]
+      assert(attributeReference.name == "window")
+      assert(attributeReference.dataType == tuple._2)
+    }
+  }
 }
