Posted to commits@spark.apache.org by ka...@apache.org on 2021/08/16 02:07:43 UTC

[spark] branch branch-3.2 updated: [SPARK-36465][SS] Dynamic gap duration in session window

This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 3aa933b  [SPARK-36465][SS] Dynamic gap duration in session window
3aa933b is described below

commit 3aa933b16245f550ad08359bb986e8f624ea4a6a
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Mon Aug 16 11:06:00 2021 +0900

    [SPARK-36465][SS] Dynamic gap duration in session window
    
    ### What changes were proposed in this pull request?
    
    This patch supports dynamic gap duration in session window.
    
    ### Why are the changes needed?
    
    The gap duration used in session window is currently a static value. To support more complex use cases, it is better to support a dynamic gap duration that is determined per row by looking at the current data. For example, in our use case, the gap may differ depending on the value of a certain column in the input rows.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, users can specify dynamic gap duration.
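    
    For example, a minimal sketch with the new Column-based gap (the `events` DataFrame and its columns are illustrative, not part of this patch):
    
        // assuming `events` is a streaming DataFrame with `timestamp` and `userId` columns
        val gap = when($"userId" === "user1", "5 seconds").otherwise("5 minutes")
        events
          .withWatermark("timestamp", "10 minutes")
          .groupBy(session_window($"timestamp", gap), $"userId")
          .count()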
    
    ### How was this patch tested?
    
    Modified existing tests and added a new test.
    
    Closes #33691 from viirya/dynamic-session-window-gap.
    
    Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
    Signed-off-by: Jungtaek Lim <ka...@gmail.com>
    (cherry picked from commit 8b8d91cf64aeb4ccc51dfe914f307e28c57081f8)
    Signed-off-by: Jungtaek Lim <ka...@gmail.com>
---
 docs/structured-streaming-programming-guide.md     |  75 ++++++++++++++-
 python/pyspark/sql/functions.py                    |  21 ++++-
 python/pyspark/sql/functions.pyi                   |   2 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  17 +++-
 .../sql/catalyst/expressions/SessionWindow.scala   |  47 ++++------
 .../spark/sql/errors/QueryCompilationErrors.scala  |   5 +
 .../scala/org/apache/spark/sql/functions.scala     |  37 ++++++++
 .../spark/sql/DataFrameSessionWindowingSuite.scala |  96 ++++++++++++++++++-
 .../streaming/StreamingSessionWindowSuite.scala    | 104 ++++++++++++++++++++-
 9 files changed, 358 insertions(+), 46 deletions(-)

diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md
index 69e5ba9..b56d8c8 100644
--- a/docs/structured-streaming-programming-guide.md
+++ b/docs/structured-streaming-programming-guide.md
@@ -1080,8 +1080,8 @@ Tumbling and sliding window use `window` function, which has been described on a
 
 Session windows have different characteristics compared to the previous two types. Session window has a dynamic size
 of the window length, depending on the inputs. A session window starts with an input, and expands itself
-if following input has been received within gap duration. A session window closes when there's no input
-received within gap duration after receiving the latest input.
+if following input has been received within gap duration. With a static gap duration, a session window closes when
+there is no input received within the gap duration after receiving the latest input.
 
 Session window uses `session_window` function. The usage of the function is similar to the `window` function.
 
@@ -1134,6 +1134,77 @@ sessionizedCounts = events \
 </div>
 </div>
 
+Instead of a static value, we can also provide an expression to specify the gap duration dynamically
+based on the input row. Note that rows with a negative or zero gap duration will be filtered
+out of the aggregation.
+
+With a dynamic gap duration, the closing of a session window no longer depends on the latest input.
+A session window's range is the union of all events' ranges, each determined by the event start
+time and the gap duration evaluated during query execution.
+
+<div class="codetabs">
+<div data-lang="scala"  markdown="1">
+
+{% highlight scala %}
+import spark.implicits._
+
+val events = ... // streaming DataFrame of schema { timestamp: Timestamp, userId: String }
+
+val sessionWindow = session_window($"timestamp", when($"userId" === "user1", "5 seconds")
+  .when($"userId" === "user2", "20 seconds")
+  .otherwise("5 minutes"))
+
+// Group the data by session window and userId, and compute the count of each group
+val sessionizedCounts = events
+    .withWatermark("timestamp", "10 minutes")
+    .groupBy(
+        sessionWindow,
+        $"userId")
+    .count()
+{% endhighlight %}
+
+</div>
+<div data-lang="java"  markdown="1">
+
+{% highlight java %}
+Dataset<Row> events = ... // streaming DataFrame of schema { timestamp: Timestamp, userId: String }
+
+Column sessionWindow = session_window(col("timestamp"), when(col("userId").equalTo("user1"), "5 seconds")
+  .when(col("userId").equalTo("user2"), "20 seconds")
+  .otherwise("5 minutes"));
+
+// Group the data by session window and userId, and compute the count of each group
+Dataset<Row> sessionizedCounts = events
+    .withWatermark("timestamp", "10 minutes")
+    .groupBy(
+        sessionWindow,
+        col("userId"))
+    .count();
+{% endhighlight %}
+
+</div>
+<div data-lang="python"  markdown="1">
+{% highlight python %}
+from pyspark.sql import functions as F
+
+events = ...  # streaming DataFrame of schema { timestamp: Timestamp, userId: String }
+
+session_window = F.session_window(events.timestamp, \
+    F.when(events.userId == "user1", "5 seconds") \
+    .when(events.userId == "user2", "20 seconds").otherwise("5 minutes"))
+
+# Group the data by session window and userId, and compute the count of each group
+sessionizedCounts = events \
+    .withWatermark("timestamp", "10 minutes") \
+    .groupBy(
+        session_window,
+        events.userId) \
+    .count()
+{% endhighlight %}
+
+</div>
+</div>
+
 Note that there are some restrictions when you use session window in streaming query, like below:
 
 - "Update mode" as output mode is not supported.
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 06d58b8..fa96ea6 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2343,6 +2343,8 @@ def session_window(timeColumn, gapDuration):
     processing time.
     gapDuration is provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid
     interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'.
+    It could also be a Column that is evaluated to a gap duration dynamically based on the
+    input row.
     The output column will be a struct called 'session_window' by default with the nested columns
     'start' and 'end', where 'start' and 'end' will be of :class:`pyspark.sql.types.TimestampType`.
     .. versionadded:: 3.2.0
@@ -2353,15 +2355,24 @@ def session_window(timeColumn, gapDuration):
     >>> w.select(w.session_window.start.cast("string").alias("start"),
     ...          w.session_window.end.cast("string").alias("end"), "sum").collect()
     [Row(start='2016-03-11 09:00:07', end='2016-03-11 09:00:12', sum=1)]
+    >>> w = df.groupBy(session_window("date", lit("5 seconds"))).agg(sum("val").alias("sum"))
+    >>> w.select(w.session_window.start.cast("string").alias("start"),
+    ...          w.session_window.end.cast("string").alias("end"), "sum").collect()
+    [Row(start='2016-03-11 09:00:07', end='2016-03-11 09:00:12', sum=1)]
     """
-    def check_string_field(field, fieldName):
-        if not field or type(field) is not str:
-            raise TypeError("%s should be provided as a string" % fieldName)
+    def check_field(field, fieldName):
+        if field is None or not isinstance(field, (str, Column)):
+            raise TypeError("%s should be provided as a string or Column" % fieldName)
 
     sc = SparkContext._active_spark_context
     time_col = _to_java_column(timeColumn)
-    check_string_field(gapDuration, "gapDuration")
-    res = sc._jvm.functions.session_window(time_col, gapDuration)
+    check_field(gapDuration, "gapDuration")
+    gap_duration = (
+        gapDuration
+        if isinstance(gapDuration, str)
+        else _to_java_column(gapDuration)
+    )
+    res = sc._jvm.functions.session_window(time_col, gap_duration)
     return Column(res)
 
 
diff --git a/python/pyspark/sql/functions.pyi b/python/pyspark/sql/functions.pyi
index 8342e65..5c39706 100644
--- a/python/pyspark/sql/functions.pyi
+++ b/python/pyspark/sql/functions.pyi
@@ -136,7 +136,7 @@ def window(
     slideDuration: Optional[str] = ...,
     startTime: Optional[str] = ...,
 ) -> Column: ...
-def session_window(timeColumn: ColumnOrName, gapDuration: str) -> Column: ...
+def session_window(timeColumn: ColumnOrName, gapDuration: Union[Column, str]) -> Column: ...
 def crc32(col: ColumnOrName) -> Column: ...
 def md5(col: ColumnOrName) -> Column: ...
 def sha1(col: ColumnOrName) -> Column: ...
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 349f9b9..468986d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3987,7 +3987,14 @@ object SessionWindowing extends Rule[LogicalPlan] {
           SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
 
         val sessionStart = PreciseTimestampConversion(session.timeColumn, TimestampType, LongType)
-        val sessionEnd = sessionStart + session.gapDuration
+        val gapDuration = session.gapDuration match {
+          case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
+            Cast(expr, CalendarIntervalType)
+          case other =>
+            throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
+        }
+        val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
+          TimestampType, LongType)
 
         val literalSessionStruct = CreateNamedStruct(
           Literal(SESSION_START) ::
@@ -4004,11 +4011,13 @@ object SessionWindowing extends Rule[LogicalPlan] {
         }
 
         // As same as tumbling window, we add a filter to filter out nulls.
-        val filterExpr = IsNotNull(session.timeColumn)
+        // And we also filter out events with negative or zero gap duration.
+        val filterExpr = IsNotNull(session.timeColumn) &&
+          (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START))
 
         replacedPlan.withNewChildren(
-          Project(sessionStruct +: child.output,
-            Filter(filterExpr, child)) :: Nil)
+          Filter(filterExpr,
+            Project(sessionStruct +: child.output, child)) :: Nil)
       } else if (numWindowExpr > 1) {
         throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
       } else {
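
In DataFrame terms, the rewritten plan is roughly equivalent to the sketch below (illustrative only: the rule itself works on Catalyst expressions, and the `events` DataFrame with `timestamp` and `gapDuration` columns is an assumption, not part of this patch):

    import org.apache.spark.sql.functions.col

    // Conceptual per-row session assignment after the rewrite:
    //   session_start = event timestamp
    //   session_end   = event timestamp + gap duration (cast to a calendar interval)
    // Rows whose end is not greater than their start (zero or negative gap) are dropped,
    // in addition to the existing null-timestamp filter.
    val withSessions = events
      .withColumn("session_start", col("timestamp"))
      .withColumn("session_end", col("timestamp") + col("gapDuration").cast("interval"))
      .filter(col("timestamp").isNotNull && col("session_end") > col("session_start"))
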
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
index 60b0744..eb46c0f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
@@ -17,32 +17,31 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.util.IntervalUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Represent the session window.
  *
  * @param timeColumn the start time of session window
- * @param gapDuration the duration of session gap, meaning the session will close if there is
- *                    no new element appeared within "the last element in session + gap".
+ * @param gapDuration the duration of the session gap. With a static gap duration, the session
+ *                    will close if no new element appears within "the last element in
+ *                    session + gap". Besides a static gap duration value, users can also provide
+ *                    an expression to specify the gap duration dynamically based on the input
+ *                    row. With a dynamic gap duration, the closing of a session window no longer
+ *                    depends on the latest input. A session window's range is the union of all
+ *                    events' ranges, each determined by the event start time and the gap
+ *                    duration evaluated during query execution. Note that rows with a negative
+ *                    or zero gap duration will be filtered out from the aggregation.
  */
-case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends UnaryExpression
+case class SessionWindow(timeColumn: Expression, gapDuration: Expression) extends Expression
   with ImplicitCastInputTypes
   with Unevaluable
   with NonSQLExpression {
 
-  //////////////////////////
-  // SQL Constructors
-  //////////////////////////
-
-  def this(timeColumn: Expression, gapDuration: Expression) = {
-    this(timeColumn, TimeWindow.parseExpression(gapDuration))
-  }
-
-  override def child: Expression = timeColumn
-  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
+  override def children: Seq[Expression] = Seq(timeColumn, gapDuration)
+  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, AnyDataType)
   override def dataType: DataType = new StructType()
     .add(StructField("start", TimestampType))
     .add(StructField("end", TimestampType))
@@ -50,19 +49,10 @@ case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends Unar
   // This expression is replaced in the analyzer.
   override lazy val resolved = false
 
-  /** Validate the inputs for the gap duration in addition to the input data type. */
-  override def checkInputDataTypes(): TypeCheckResult = {
-    val dataTypeCheck = super.checkInputDataTypes()
-    if (dataTypeCheck.isSuccess) {
-      if (gapDuration <= 0) {
-        return TypeCheckFailure(s"The window duration ($gapDuration) must be greater than 0.")
-      }
-    }
-    dataTypeCheck
-  }
+  override def nullable: Boolean = false
 
-  override protected def withNewChildInternal(newChild: Expression): Expression =
-    copy(timeColumn = newChild)
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(timeColumn = newChildren(0), gapDuration = newChildren(1))
 }
 
 object SessionWindow {
@@ -72,6 +62,7 @@ object SessionWindow {
       timeColumn: Expression,
       gapDuration: String): SessionWindow = {
     SessionWindow(timeColumn,
-      TimeWindow.getIntervalInMicroSeconds(gapDuration))
+      Literal(IntervalUtils.safeStringToInterval(UTF8String.fromString(gapDuration)),
+        CalendarIntervalType))
   }
 }
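
With this change, the gap is carried as a Catalyst expression instead of a pre-parsed Long. A minimal sketch of the two ways a SessionWindow can now be built (for illustration only; these are internal Catalyst classes rather than the public API):

    import org.apache.spark.sql.catalyst.expressions.SessionWindow
    import org.apache.spark.sql.functions.{col, when}

    // Static gap: the String overload wraps the parsed interval in a CalendarIntervalType literal.
    val staticWindow = SessionWindow(col("timestamp").expr, "10 seconds")

    // Dynamic gap: any expression; the analyzer later casts it to CalendarIntervalType
    // and rejects types that cannot be cast.
    val gap = when(col("userId") === "user1", "5 seconds").otherwise("5 minutes")
    val dynamicWindow = SessionWindow(col("timestamp").expr, gap.expr)
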
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index b62729a..eedf038 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -371,6 +371,11 @@ private[spark] object QueryCompilationErrors {
       t.origin.startPosition)
   }
 
+  def sessionWindowGapDurationDataTypeError(dt: DataType): Throwable = {
+    new AnalysisException("Gap duration expression used in session window must be " +
+      s"CalendarIntervalType, but got ${dt}")
+  }
+
   def viewOutputNumberMismatchQueryColumnNamesError(
       output: Seq[Attribute], queryColumnNames: Seq[String]): Throwable = {
     new AnalysisException(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 1ecd3e0..da82ac5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3661,6 +3661,43 @@ object functions {
   }
 
   /**
+   * Generates session window given a timestamp specifying column.
+   *
+   * Session window is one of dynamic windows, which means the length of the window varies
+   * according to the given inputs. With a static gap duration, the length of a session window
+   * is defined as "the timestamp of the latest input of the session + gap duration", so when
+   * new inputs are bound to the current session window, the end time of the session window
+   * can be expanded according to the new inputs.
+   *
+   * Besides a static gap duration value, users can also provide an expression to specify
+   * the gap duration dynamically based on the input row. With a dynamic gap duration, the
+   * closing of a session window no longer depends on the latest input. A session window's
+   * range is the union of all events' ranges, each determined by the event start time and
+   * the gap duration evaluated during query execution. Note that rows with a negative or
+   * zero gap duration will be filtered out from the aggregation.
+   *
+   * Windows can support microsecond precision. Gap durations on the order of months are not
+   * supported.
+   *
+   * For a streaming query, you may use the function `current_timestamp` to generate windows on
+   * processing time.
+   *
+   * @param timeColumn The column or the expression to use as the timestamp for windowing by time.
+   *                   The time column must be of TimestampType.
+   * @param gapDuration A column specifying the timeout of the session. It could be a static
+   *                    value, e.g. `10 minutes`, `1 second`, or an expression/UDF that specifies
+   *                    the gap duration dynamically based on the input row.
+   *
+   * @group datetime_funcs
+   * @since 3.2.0
+   */
+  def session_window(timeColumn: Column, gapDuration: Column): Column = {
+    withExpr {
+      SessionWindow(timeColumn.expr, gapDuration.expr)
+    }.as("session_window")
+  }
+
+  /**
    * Creates timestamp from the number of seconds since UTC epoch.
    * @group datetime_funcs
    * @since 3.1.0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
index b70b2c6..7a0cd42 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
@@ -263,9 +263,10 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
   private def withTempTable(f: String => Unit): Unit = {
     val tableName = "temp"
     Seq(
-      ("2016-03-27 19:39:34", 1),
-      ("2016-03-27 19:39:56", 2),
-      ("2016-03-27 19:39:27", 4)).toDF("time", "value").createOrReplaceTempView(tableName)
+      ("2016-03-27 19:39:34", 1, "10 seconds"),
+      ("2016-03-27 19:39:56", 2, "20 seconds"),
+      ("2016-03-27 19:39:27", 4, "30 seconds")).toDF("time", "value", "duration")
+      .createOrReplaceTempView(tableName)
     try {
       f(tableName)
     } finally {
@@ -287,4 +288,93 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
       )
     }
   }
+
+  test("SPARK-36465: time window in SQL with dynamic string expression") {
+    withTempTable { table =>
+      checkAnswer(
+        spark.sql(s"""select session_window(time, duration), value from $table""")
+          .select($"session_window.start".cast(StringType), $"session_window.end".cast(StringType),
+            $"value"),
+        Seq(
+          Row("2016-03-27 19:39:27", "2016-03-27 19:39:57", 4),
+          Row("2016-03-27 19:39:34", "2016-03-27 19:39:44", 1),
+          Row("2016-03-27 19:39:56", "2016-03-27 19:40:16", 2)
+        )
+      )
+    }
+  }
+
+  test("SPARK-36465: Unsupported dynamic gap datatype") {
+    withTempTable { table =>
+      val err = intercept[AnalysisException] {
+        spark.sql(s"""select session_window(time, 1.0), value from $table""")
+          .select($"session_window.start".cast(StringType), $"session_window.end".cast(StringType),
+            $"value")
+      }
+      assert(err.message.contains("Gap duration expression used in session window must be " +
+        "CalendarIntervalType, but got DecimalType(2,1)"))
+    }
+  }
+
+  test("SPARK-36465: time window in SQL with UDF as gap duration") {
+    withTempTable { table =>
+
+      spark.udf.register("gapDuration",
+        (i: java.lang.Integer) => s"${i * 10} seconds")
+
+      checkAnswer(
+        spark.sql(s"""select session_window(time, gapDuration(value)), value from $table""")
+          .select($"session_window.start".cast(StringType), $"session_window.end".cast(StringType),
+            $"value"),
+        Seq(
+          Row("2016-03-27 19:39:27", "2016-03-27 19:40:07", 4),
+          Row("2016-03-27 19:39:34", "2016-03-27 19:39:44", 1),
+          Row("2016-03-27 19:39:56", "2016-03-27 19:40:16", 2)
+        )
+      )
+    }
+  }
+
+  test("SPARK-36465: time window in SQL with conditional expression as gap duration") {
+    withTempTable { table =>
+
+      checkAnswer(
+        spark.sql("select session_window(time, " +
+          """case when value = 1 then "2 seconds" when value = 2 then "10 seconds" """ +
+          s"""else "20 seconds" end), value from $table""")
+          .select($"session_window.start".cast(StringType), $"session_window.end".cast(StringType),
+            $"value"),
+        Seq(
+          Row("2016-03-27 19:39:27", "2016-03-27 19:39:47", 4),
+          Row("2016-03-27 19:39:34", "2016-03-27 19:39:36", 1),
+          Row("2016-03-27 19:39:56", "2016-03-27 19:40:06", 2)
+        )
+      )
+    }
+  }
+
+  test("SPARK-36465: filter out events with negative/zero gap duration") {
+    withTempTable { table =>
+
+      spark.udf.register("gapDuration",
+        (i: java.lang.Integer) => {
+          if (i == 1) {
+            "0 seconds"
+          } else if (i == 2) {
+            "-10 seconds"
+          } else {
+            "5 seconds"
+          }
+        })
+
+      checkAnswer(
+        spark.sql(s"""select session_window(time, gapDuration(value)), value from $table""")
+          .groupBy($"session_window")
+          .agg(count("*").as("counts"))
+          .select($"session_window.start".cast("string"), $"session_window.end".cast("string"),
+            $"counts"),
+        Seq(Row("2016-03-27 19:39:27", "2016-03-27 19:39:32", 1))
+      )
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala
index c9a5e1a..f41f708 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.BeforeAndAfter
 import org.scalatest.matchers.must.Matchers
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, DataFrame}
+import org.apache.spark.sql.{AnalysisException, Column, DataFrame}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider}
 import org.apache.spark.sql.functions.{count, session_window, sum}
@@ -256,6 +256,102 @@ class StreamingSessionWindowSuite extends StreamTest
     )
   }
 
+  testWithAllOptions("SPARK-36465: dynamic gap duration") {
+    val inputData = MemoryStream[(String, Long)]
+
+    val udf = spark.udf.register("gapDuration", (s: String) => {
+      if (s == "hello") {
+        "1 second"
+      } else if (s == "structured") {
+        // zero gap duration will be filtered out from aggregation
+        "0 second"
+      } else if (s == "world") {
+        // negative gap duration will be filtered out from aggregation
+        "-10 seconds"
+      } else {
+        "10 seconds"
+      }
+    })
+
+    val sessionUpdates = sessionWindowQuery(inputData,
+      session_window($"eventTime", udf($"sessionId")))
+
+    testStream(sessionUpdates, OutputMode.Append())(
+      AddData(inputData,
+        ("hello world spark streaming", 40L),
+        ("world hello structured streaming", 41L)
+      ),
+
+      // watermark: 11
+      // current sessions
+      // ("hello", 40, 42, 2, 2),
+      // ("streaming", 40, 51, 11, 2),
+      // ("spark", 40, 50, 10, 1),
+      CheckNewAnswer(
+      ),
+
+      // placing new sessions "before" previous sessions
+      AddData(inputData, ("spark streaming", 25L)),
+      // watermark: 11
+      // current sessions
+      // ("spark", 25, 35, 10, 1),
+      // ("streaming", 25, 35, 10, 1),
+      // ("hello", 40, 42, 2, 2),
+      // ("streaming", 40, 51, 11, 2),
+      // ("spark", 40, 50, 10, 1),
+      CheckNewAnswer(
+      ),
+
+      // late event whose session's end (10) would be earlier than watermark 11: should be dropped
+      AddData(inputData, ("spark streaming", 0L)),
+      // watermark: 11
+      // current sessions
+      // ("spark", 25, 35, 10, 1),
+      // ("streaming", 25, 35, 10, 1),
+      // ("hello", 40, 42, 2, 2),
+      // ("streaming", 40, 51, 11, 2),
+      // ("spark", 40, 50, 10, 1),
+      CheckNewAnswer(
+      ),
+
+      // concatenating multiple previous sessions into one
+      AddData(inputData, ("spark streaming", 30L)),
+      // watermark: 11
+      // current sessions
+      // ("spark", 25, 50, 25, 3),
+      // ("streaming", 25, 51, 26, 4),
+      // ("hello", 40, 42, 2, 2),
+      CheckNewAnswer(
+      ),
+
+      // placing new sessions after previous sessions
+      AddData(inputData, ("hello apache spark", 60L)),
+      // watermark: 30
+      // current sessions
+      // ("spark", 25, 50, 25, 3),
+      // ("streaming", 25, 51, 26, 4),
+      // ("hello", 40, 42, 2, 2),
+      // ("hello", 60, 61, 1, 1),
+      // ("apache", 60, 70, 10, 1),
+      // ("spark", 60, 70, 10, 1)
+      CheckNewAnswer(
+      ),
+
+      AddData(inputData, ("structured streaming", 90L)),
+      // watermark: 60
+      // current sessions
+      // ("hello", 60, 61, 1, 1),
+      // ("apache", 60, 70, 10, 1),
+      // ("spark", 60, 70, 10, 1),
+      // ("streaming", 90, 100, 10, 1)
+      CheckNewAnswer(
+        ("spark", 25, 50, 25, 3),
+        ("streaming", 25, 51, 26, 4),
+        ("hello", 40, 42, 2, 2)
+      )
+    )
+  }
+
   testWithAllOptions("append mode - session window - no key") {
     val inputData = MemoryStream[Int]
     val windowedAggregation = sessionWindowQueryOnGlobalKey(inputData)
@@ -304,7 +400,9 @@ class StreamingSessionWindowSuite extends StreamTest
     }
   }
 
-  private def sessionWindowQuery(input: MemoryStream[(String, Long)]): DataFrame = {
+  private def sessionWindowQuery(
+      input: MemoryStream[(String, Long)],
+      sessionWindow: Column = session_window($"eventTime", "10 seconds")): DataFrame = {
     // Split the lines into words, treat words as sessionId of events
     val events = input.toDF()
       .select($"_1".as("value"), $"_2".as("timestamp"))
@@ -313,7 +411,7 @@ class StreamingSessionWindowSuite extends StreamTest
       .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
 
     events
-      .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId)
+      .groupBy(sessionWindow as 'session, 'sessionId)
       .agg(count("*").as("numEvents"))
       .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
         "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org