Posted to reviews@spark.apache.org by "HeartSaVioR (via GitHub)" <gi...@apache.org> on 2024/03/08 05:34:20 UTC

Re: [PR] [SPARK-46913][SS] Add support for processing/event time based timers with transformWithState operator [spark]

HeartSaVioR commented on code in PR #45051:
URL: https://github.com/apache/spark/pull/45051#discussion_r1517119556


##########
sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala:
##########
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.io.Serializable
+
+import org.apache.spark.annotation.{Evolving, Experimental}
+
+/**
+ * Class used to provide access to expired timer's expiry time and timeout mode. These values

Review Comment:
   nit: Technically, the timeout mode is not visible through the trait. If that's intentional, consider removing that part from the interface doc.
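   To illustrate: in the minimal model below, a user function typed against the trait can only call what the trait declares, so a `timeoutMode` field that exists only on the implementation class is unreachable from user code. The names mirror the PR, but the exact member set (`isValid`, `getExpiryTimeInMs`) and the `-1L` sentinel are assumptions made for this sketch, not the real Spark definitions.

```scala
// Hypothetical, simplified stand-ins for the PR's types; not the real Spark definitions.
sealed trait Mode
case object NoTimeouts extends Mode
case object ProcessingTime extends Mode

// The user-facing trait: only these members are visible to user code.
trait ExpiredTimerInfo extends java.io.Serializable {
  def isValid(): Boolean
  def getExpiryTimeInMs(): Long
}

// The implementation carries an extra field that the trait does not expose.
class ExpiredTimerInfoImpl(
    valid: Boolean,
    expiryTimeInMsOpt: Option[Long] = None,
    val timeoutMode: Mode = NoTimeouts) extends ExpiredTimerInfo {
  override def isValid(): Boolean = valid
  // Using -1L to mean "no expiry time" is an assumption made for this sketch.
  override def getExpiryTimeInMs(): Long = expiryTimeInMsOpt.getOrElse(-1L)
}

object TimerInfoSketch {
  // The callback is typed against the trait, so `info.timeoutMode` would not compile here.
  def describe(info: ExpiredTimerInfo): String =
    if (info.isValid()) s"expired at ${info.getExpiryTimeInMs()}" else "no expired timer"
}
```

   So either the doc should drop the mention of the timeout mode, or the trait should expose it explicitly.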



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ExpiredTimerInfoImpl.scala:
##########
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.sql.streaming.{ExpiredTimerInfo, TimeoutMode}
+
+/**
+ * Class that provides a concrete implementation that can be used to provide access to expired
+ * timer's expiry time and timeout mode. These values are only relevant if the ExpiredTimerInfo

Review Comment:
   nit: Same here; the timeout mode is not visible to the user function, AFAIK.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -103,8 +116,12 @@ case class TransformWithStateExec(
     val keyObj = getKeyObj(keyRow)  // convert key to objects
     ImplicitGroupingKeyTracker.setImplicitKey(keyObj)
     val valueObjIter = valueRowIter.map(getValueObj.apply)
-    val mappedIterator = statefulProcessor.handleInputRows(keyObj, valueObjIter,
-      new TimerValuesImpl(batchTimestampMs, eventTimeWatermarkForLateEvents)).map { obj =>
+    val mappedIterator = statefulProcessor.handleInputRows(
+      keyObj,
+      valueObjIter,
+      new TimerValuesImpl(batchTimestampMs, eventTimeWatermarkForLateEvents),
+      new ExpiredTimerInfoImpl(false)

Review Comment:
   super nit / 2 cents: a named parameter for the boolean would give much better readability in a non-IDE environment. This is more of a general suggestion and personal preference, so feel free to ignore it.
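   For example, assuming the constructor's first parameter is named `isValid` (an assumption; the hunk only shows `new ExpiredTimerInfoImpl(false)`), Scala's named arguments make the call site self-describing without changing behavior. The class below is a hypothetical stand-in.

```scala
// Hypothetical stand-in for the PR's ExpiredTimerInfoImpl; only the constructor shape matters.
class ExpiredTimerInfoImpl(val isValid: Boolean, val expiryTimeInMsOpt: Option[Long] = None)

object NamedArgSketch {
  // Positional: a reader without an IDE has to look up what `false` means.
  val positional = new ExpiredTimerInfoImpl(false)

  // Named: the intent is visible at the call site; the resulting object is the same.
  val named = new ExpiredTimerInfoImpl(isValid = false)
}
```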



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala:
##########
@@ -121,6 +123,46 @@ class StatefulProcessorHandleImpl(
 
   override def getQueryInfo(): QueryInfo = currQueryInfo
 
+  private def getTimerState[T](): TimerStateImpl[T] = {
+    new TimerStateImpl[T](store, timeoutMode, keyEncoder)
+  }
+
+  private val timerState = getTimerState[Boolean]()
+
+  override def registerTimer(expiryTimestampMs: Long): Unit = {
+    if (!(timeoutMode == ProcessingTime || timeoutMode == EventTime)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidTimeoutMode(timeoutMode.toString)
+    }
+
+    if (!(currState == INITIALIZED || currState == DATA_PROCESSED)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidHandleState(currState.toString)
+    }
+
+    if (timerState.exists(expiryTimestampMs)) {
+      logWarning(s"Timer already exists for expiryTimestampMs=$expiryTimestampMs")

Review Comment:
   I'm OK with moving the logging to TimerStateImpl if it helps to solve this.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala:
##########
@@ -121,6 +123,46 @@ class StatefulProcessorHandleImpl(
 
   override def getQueryInfo(): QueryInfo = currQueryInfo
 
+  private def getTimerState[T](): TimerStateImpl[T] = {
+    new TimerStateImpl[T](store, timeoutMode, keyEncoder)
+  }
+
+  private val timerState = getTimerState[Boolean]()
+
+  override def registerTimer(expiryTimestampMs: Long): Unit = {
+    if (!(timeoutMode == ProcessingTime || timeoutMode == EventTime)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidTimeoutMode(timeoutMode.toString)
+    }
+
+    if (!(currState == INITIALIZED || currState == DATA_PROCESSED)) {

Review Comment:
   nit: maybe `currState < TIMER_PROCESSED`? Given that we ordered the enum to reflect the sequence of phases, this makes it easier to understand which state is required to perform this operation.
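   A small sketch of the ordering-based check, assuming a Scala `Enumeration` whose members are declared in phase order (the member list `CREATED, INITIALIZED, DATA_PROCESSED, TIMER_PROCESSED, CLOSED` is an assumption). `Enumeration#Value` is `Ordered` by declaration order, so `<` and `>=` work directly. Note that a bare `currState < TIMER_PROCESSED` would also admit any phase declared before `INITIALIZED` (such as a `CREATED` state, if one exists), so a lower bound may still be needed to stay equivalent to the original check.

```scala
// Hypothetical copy of the handle-state enum; the real member list in the PR may differ.
object StatefulProcessorHandleState extends Enumeration {
  type StatefulProcessorHandleState = Value
  // Declaration order defines the ordering used by `<`, `>=`, etc.
  val CREATED, INITIALIZED, DATA_PROCESSED, TIMER_PROCESSED, CLOSED = Value
}

object HandleStateCheck {
  import StatefulProcessorHandleState._

  // Original style: enumerate the allowed states explicitly.
  def allowedExplicit(s: StatefulProcessorHandleState): Boolean =
    s == INITIALIZED || s == DATA_PROCESSED

  // Ordering-based style: the lower bound keeps CREATED excluded.
  def allowedByOrder(s: StatefulProcessorHandleState): Boolean =
    s >= INITIALIZED && s < TIMER_PROCESSED
}
```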



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala:
##########
@@ -121,6 +123,46 @@ class StatefulProcessorHandleImpl(
 
   override def getQueryInfo(): QueryInfo = currQueryInfo
 
+  private def getTimerState[T](): TimerStateImpl[T] = {
+    new TimerStateImpl[T](store, timeoutMode, keyEncoder)
+  }
+
+  private val timerState = getTimerState[Boolean]()
+
+  override def registerTimer(expiryTimestampMs: Long): Unit = {
+    if (!(timeoutMode == ProcessingTime || timeoutMode == EventTime)) {

Review Comment:
   nit: Maybe just compare with NoTimeout? I don't imagine we will ever add more timeout modes. Also, if we ever do add a new timeout mode, it is more likely to be another valid timeout mode than another representation of "no timeout".
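   A self-contained sketch of the two styles, using a hypothetical sealed `TimeoutMode` model (the names and the `IllegalStateException` stand-in for `StateStoreErrors.cannotUseTimersWithInvalidTimeoutMode` are assumptions). Both checks agree today; the difference is that the deny-list form automatically accepts a future valid mode, while the allow-list form would need updating.

```scala
// Hypothetical, self-contained model of the timeout modes; names are assumptions.
sealed trait TimeoutMode
case object NoTimeouts extends TimeoutMode
case object ProcessingTime extends TimeoutMode
case object EventTime extends TimeoutMode

object TimerGuard {
  // Original style: allow-list every mode that supports timers.
  def timersAllowedAllowList(mode: TimeoutMode): Boolean =
    mode == ProcessingTime || mode == EventTime

  // Suggested style: reject the single mode that does not support timers.
  def timersAllowed(mode: TimeoutMode): Boolean = mode != NoTimeouts

  // Stand-in for the PR's error helper; the real code throws a StateStoreErrors error.
  def requireTimers(mode: TimeoutMode): Unit =
    if (!timersAllowed(mode)) {
      throw new IllegalStateException(s"Cannot use timers with timeout mode $mode")
    }
}
```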



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala:
##########
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import java.io.Serializable
+import java.nio.{ByteBuffer, ByteOrder}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.streaming.TimeoutMode
+import org.apache.spark.sql.types._
+import org.apache.spark.util.NextIterator
+
+/**
+ * Singleton utils class used primarily while interacting with TimerState
+ */
+object TimerStateUtils {
+  case class TimestampWithKey(
+      key: Any,
+      expiryTimestampMs: Long) extends Serializable
+
+  val PROC_TIMERS_STATE_NAME = "_procTimers"
+  val EVENT_TIMERS_STATE_NAME = "_eventTimers"
+  val KEY_TO_TIMESTAMP_CF = "_keyToTimestamp"
+  val TIMESTAMP_TO_KEY_CF = "_timestampToKey"
+}
+
+/**
+ * Class that provides the implementation for storing timers
+ * used within the `transformWithState` operator.
+ * @param store - state store to be used for storing timer data
+ * @param timeoutMode - mode of timeout (event time or processing time)
+ * @param keyExprEnc - encoder for key expression
+ * @tparam S - type of timer value
+ */
+class TimerStateImpl[S](
+    store: StateStore,
+    timeoutMode: TimeoutMode,
+    keyExprEnc: ExpressionEncoder[Any]) extends Logging {
+
+  private val EMPTY_ROW =
+    UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
+
+  private val schemaForPrefixKey: StructType = new StructType()
+    .add("key", BinaryType)
+
+  private val schemaForKeyRow: StructType = new StructType()
+    .add("key", BinaryType)
+    .add("expiryTimestampMs", LongType, nullable = false)
+
+  private val keySchemaForSecIndex: StructType = new StructType()
+    .add("expiryTimestampMs", BinaryType, nullable = false)
+    .add("key", BinaryType)
+
+  private val schemaForValueRow: StructType =
+    StructType(Array(StructField("__dummy__", NullType)))
+
+  private val keySerializer = keyExprEnc.createSerializer()
+
+  private val prefixKeyEncoder = UnsafeProjection.create(schemaForPrefixKey)
+
+  private val keyEncoder = UnsafeProjection.create(schemaForKeyRow)
+
+  private val secIndexKeyEncoder = UnsafeProjection.create(keySchemaForSecIndex)
+
+  val timerCfName = if (timeoutMode == TimeoutMode.ProcessingTime) {

Review Comment:
   nit: `CF` (i.e. `timerCFName`)? For consistency with the names below (`keyToTsCFName`, `tsToKeyCFName`).
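   For context on the hunk above: `TimerStateImpl` keeps two column families per timeout mode (suffixed `_keyToTimestamp` and `_timestampToKey` per `TimerStateUtils`). The hypothetical in-memory analogue below shows why both exist: the key-first index answers "which timers does this grouping key have?", and the timestamp-first index answers "which timers have expired?" with a single ordered scan. A `String` stands in for the serialized grouping key and Scala sorted collections stand in for RocksDB key order; treating "expired" as `<= threshold` is an assumption.

```scala
import scala.collection.mutable

// Hypothetical in-memory analogue of the two timer column families (not Spark code).
class InMemoryTimers {
  // keyToTimestamp: grouping key -> its expiry timestamps (a prefix scan in the PR).
  private val keyToTs = mutable.Map.empty[String, mutable.TreeSet[Long]]
  // timestampToKey: (expiry, key) pairs ordered by expiry first (a range scan in the PR).
  private val tsToKey = mutable.TreeSet.empty[(Long, String)]

  def exists(key: String, ts: Long): Boolean =
    keyToTs.get(key).exists(_.contains(ts))

  // Both indexes are updated together, mirroring TimerStateImpl.add.
  def add(key: String, ts: Long): Unit = {
    keyToTs.getOrElseUpdate(key, mutable.TreeSet.empty[Long]) += ts
    tsToKey += ((ts, key))
  }

  // Both indexes are updated together, mirroring TimerStateImpl.remove.
  def remove(key: String, ts: Long): Unit = {
    keyToTs.get(key).foreach(_ -= ts)
    tsToKey -= ((ts, key))
  }

  // All timers for one grouping key, in ascending order.
  def listTimers(key: String): List[Long] =
    keyToTs.get(key).map(_.toList).getOrElse(Nil)

  // Every (expiry, key) with expiry <= threshold, stopping at the first later timer.
  def expired(thresholdMs: Long): List[(Long, String)] =
    tsToKey.iterator.takeWhile(_._1 <= thresholdMs).toList
}
```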



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala:
##########
@@ -121,6 +123,46 @@ class StatefulProcessorHandleImpl(
 
   override def getQueryInfo(): QueryInfo = currQueryInfo
 
+  private def getTimerState[T](): TimerStateImpl[T] = {
+    new TimerStateImpl[T](store, timeoutMode, keyEncoder)
+  }
+
+  private val timerState = getTimerState[Boolean]()
+
+  override def registerTimer(expiryTimestampMs: Long): Unit = {
+    if (!(timeoutMode == ProcessingTime || timeoutMode == EventTime)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidTimeoutMode(timeoutMode.toString)
+    }
+
+    if (!(currState == INITIALIZED || currState == DATA_PROCESSED)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidHandleState(currState.toString)
+    }
+
+    if (timerState.exists(expiryTimestampMs)) {
+      logWarning(s"Timer already exists for expiryTimestampMs=$expiryTimestampMs")
+    } else {
+      logInfo(s"Registering timer with expiryTimestampMs=$expiryTimestampMs")

Review Comment:
   ditto



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala:
##########
@@ -121,6 +123,46 @@ class StatefulProcessorHandleImpl(
 
   override def getQueryInfo(): QueryInfo = currQueryInfo
 
+  private def getTimerState[T](): TimerStateImpl[T] = {

Review Comment:
   nit: This is a one-liner. Is it referenced from only one place? If so, let's inline it.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala:
##########
@@ -121,6 +123,46 @@ class StatefulProcessorHandleImpl(
 
   override def getQueryInfo(): QueryInfo = currQueryInfo
 
+  private def getTimerState[T](): TimerStateImpl[T] = {
+    new TimerStateImpl[T](store, timeoutMode, keyEncoder)
+  }
+
+  private val timerState = getTimerState[Boolean]()
+
+  override def registerTimer(expiryTimestampMs: Long): Unit = {
+    if (!(timeoutMode == ProcessingTime || timeoutMode == EventTime)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidTimeoutMode(timeoutMode.toString)
+    }
+
+    if (!(currState == INITIALIZED || currState == DATA_PROCESSED)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidHandleState(currState.toString)
+    }
+
+    if (timerState.exists(expiryTimestampMs)) {
+      logWarning(s"Timer already exists for expiryTimestampMs=$expiryTimestampMs")
+    } else {
+      logInfo(s"Registering timer with expiryTimestampMs=$expiryTimestampMs")
+      timerState.add(expiryTimestampMs, true)
+    }
+  }
+
+  override def deleteTimer(expiryTimestampMs: Long): Unit = {
+    if (!(timeoutMode == ProcessingTime || timeoutMode == EventTime)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidTimeoutMode(timeoutMode.toString)
+    }
+
+    if (!(currState == INITIALIZED || currState == DATA_PROCESSED)) {

Review Comment:
   nit: ditto



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala:
##########
@@ -121,6 +123,46 @@ class StatefulProcessorHandleImpl(
 
   override def getQueryInfo(): QueryInfo = currQueryInfo
 
+  private def getTimerState[T](): TimerStateImpl[T] = {
+    new TimerStateImpl[T](store, timeoutMode, keyEncoder)
+  }
+
+  private val timerState = getTimerState[Boolean]()
+
+  override def registerTimer(expiryTimestampMs: Long): Unit = {
+    if (!(timeoutMode == ProcessingTime || timeoutMode == EventTime)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidTimeoutMode(timeoutMode.toString)
+    }
+
+    if (!(currState == INITIALIZED || currState == DATA_PROCESSED)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidHandleState(currState.toString)
+    }
+
+    if (timerState.exists(expiryTimestampMs)) {
+      logWarning(s"Timer already exists for expiryTimestampMs=$expiryTimestampMs")
+    } else {
+      logInfo(s"Registering timer with expiryTimestampMs=$expiryTimestampMs")
+      timerState.add(expiryTimestampMs, true)
+    }
+  }
+
+  override def deleteTimer(expiryTimestampMs: Long): Unit = {
+    if (!(timeoutMode == ProcessingTime || timeoutMode == EventTime)) {

Review Comment:
   nit: ditto



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala:
##########
@@ -121,6 +123,46 @@ class StatefulProcessorHandleImpl(
 
   override def getQueryInfo(): QueryInfo = currQueryInfo
 
+  private def getTimerState[T](): TimerStateImpl[T] = {
+    new TimerStateImpl[T](store, timeoutMode, keyEncoder)
+  }
+
+  private val timerState = getTimerState[Boolean]()
+
+  override def registerTimer(expiryTimestampMs: Long): Unit = {
+    if (!(timeoutMode == ProcessingTime || timeoutMode == EventTime)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidTimeoutMode(timeoutMode.toString)
+    }
+
+    if (!(currState == INITIALIZED || currState == DATA_PROCESSED)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidHandleState(currState.toString)
+    }
+
+    if (timerState.exists(expiryTimestampMs)) {
+      logWarning(s"Timer already exists for expiryTimestampMs=$expiryTimestampMs")
+    } else {
+      logInfo(s"Registering timer with expiryTimestampMs=$expiryTimestampMs")
+      timerState.add(expiryTimestampMs, true)
+    }
+  }
+
+  override def deleteTimer(expiryTimestampMs: Long): Unit = {
+    if (!(timeoutMode == ProcessingTime || timeoutMode == EventTime)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidTimeoutMode(timeoutMode.toString)
+    }
+
+    if (!(currState == INITIALIZED || currState == DATA_PROCESSED)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidHandleState(currState.toString)
+    }
+
+    if (!timerState.exists(expiryTimestampMs)) {
+      logInfo(s"Timer does not exist for expiryTimestampMs=$expiryTimestampMs")

Review Comment:
   ditto



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala:
##########
@@ -121,6 +123,46 @@ class StatefulProcessorHandleImpl(
 
   override def getQueryInfo(): QueryInfo = currQueryInfo
 
+  private def getTimerState[T](): TimerStateImpl[T] = {
+    new TimerStateImpl[T](store, timeoutMode, keyEncoder)
+  }
+
+  private val timerState = getTimerState[Boolean]()
+
+  override def registerTimer(expiryTimestampMs: Long): Unit = {
+    if (!(timeoutMode == ProcessingTime || timeoutMode == EventTime)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidTimeoutMode(timeoutMode.toString)
+    }
+
+    if (!(currState == INITIALIZED || currState == DATA_PROCESSED)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidHandleState(currState.toString)
+    }
+
+    if (timerState.exists(expiryTimestampMs)) {
+      logWarning(s"Timer already exists for expiryTimestampMs=$expiryTimestampMs")
+    } else {
+      logInfo(s"Registering timer with expiryTimestampMs=$expiryTimestampMs")
+      timerState.add(expiryTimestampMs, true)
+    }
+  }
+
+  override def deleteTimer(expiryTimestampMs: Long): Unit = {
+    if (!(timeoutMode == ProcessingTime || timeoutMode == EventTime)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidTimeoutMode(timeoutMode.toString)
+    }
+
+    if (!(currState == INITIALIZED || currState == DATA_PROCESSED)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidHandleState(currState.toString)
+    }
+
+    if (!timerState.exists(expiryTimestampMs)) {
+      logInfo(s"Timer does not exist for expiryTimestampMs=$expiryTimestampMs")
+    } else {
+      logInfo(s"Removing timer with expiryTimestampMs=$expiryTimestampMs")

Review Comment:
   ditto



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala:
##########
@@ -0,0 +1,224 @@
+package org.apache.spark.sql.execution.streaming
+
+import java.io.Serializable
+import java.nio.{ByteBuffer, ByteOrder}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.streaming.TimeoutMode
+import org.apache.spark.sql.types._
+import org.apache.spark.util.NextIterator
+
+/**
+ * Singleton utils class used primarily while interacting with TimerState
+ */
+object TimerStateUtils {
+  case class TimestampWithKey(
+      key: Any,
+      expiryTimestampMs: Long) extends Serializable
+
+  val PROC_TIMERS_STATE_NAME = "_procTimers"
+  val EVENT_TIMERS_STATE_NAME = "_eventTimers"
+  val KEY_TO_TIMESTAMP_CF = "_keyToTimestamp"
+  val TIMESTAMP_TO_KEY_CF = "_timestampToKey"
+}
+
+/**
+ * Class that provides the implementation for storing timers
+ * used within the `transformWithState` operator.
+ * @param store - state store to be used for storing timer data
+ * @param timeoutMode - mode of timeout (event time or processing time)
+ * @param keyExprEnc - encoder for key expression
+ * @tparam S - type of timer value
+ */
+class TimerStateImpl[S](
+    store: StateStore,
+    timeoutMode: TimeoutMode,
+    keyExprEnc: ExpressionEncoder[Any]) extends Logging {
+
+  private val EMPTY_ROW =
+    UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
+
+  private val schemaForPrefixKey: StructType = new StructType()
+    .add("key", BinaryType)
+
+  private val schemaForKeyRow: StructType = new StructType()
+    .add("key", BinaryType)
+    .add("expiryTimestampMs", LongType, nullable = false)
+
+  private val keySchemaForSecIndex: StructType = new StructType()
+    .add("expiryTimestampMs", BinaryType, nullable = false)
+    .add("key", BinaryType)
+
+  private val schemaForValueRow: StructType =
+    StructType(Array(StructField("__dummy__", NullType)))
+
+  private val keySerializer = keyExprEnc.createSerializer()
+
+  private val prefixKeyEncoder = UnsafeProjection.create(schemaForPrefixKey)
+
+  private val keyEncoder = UnsafeProjection.create(schemaForKeyRow)
+
+  private val secIndexKeyEncoder = UnsafeProjection.create(keySchemaForSecIndex)
+
+  val timerCfName = if (timeoutMode == TimeoutMode.ProcessingTime) {
+    TimerStateUtils.PROC_TIMERS_STATE_NAME
+  } else {
+    TimerStateUtils.EVENT_TIMERS_STATE_NAME
+  }
+
+  val keyToTsCFName = timerCfName + TimerStateUtils.KEY_TO_TIMESTAMP_CF
+  store.createColFamilyIfAbsent(keyToTsCFName,
+    schemaForKeyRow, numColsPrefixKey = 1,
+    schemaForValueRow, useMultipleValuesPerKey = false,
+    isInternal = true)
+
+  val tsToKeyCFName = timerCfName + TimerStateUtils.TIMESTAMP_TO_KEY_CF
+  store.createColFamilyIfAbsent(tsToKeyCFName,
+    keySchemaForSecIndex, numColsPrefixKey = 0,
+    schemaForValueRow, useMultipleValuesPerKey = false,
+    isInternal = true)
+
+  private def encodeKey(expiryTimestampMs: Long): UnsafeRow = {
+    val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
+    if (!keyOption.isDefined) {
+      throw StateStoreErrors.implicitKeyNotFound(keyToTsCFName)
+    }
+
+    val keyByteArr = keySerializer.apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()
+    val keyRow = keyEncoder(InternalRow(keyByteArr, expiryTimestampMs))
+    keyRow
+  }
+
+  //  We maintain a secondary index that inverts the ordering of the timestamp
+  //  and grouping key and maintains the list of (expiry) timestamps in sorted order
+  //  (using BIG_ENDIAN encoding) within RocksDB.
+  //  This is because RocksDB uses byte-wise comparison using the default comparator to
+  //  determine sorted order of keys. This is used to read expired timers at any given
+  //  processing time/event time timestamp threshold by performing a range scan.
+  private def encodeSecIndexKey(expiryTimestampMs: Long): UnsafeRow = {
+    val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
+    if (!keyOption.isDefined) {
+      throw StateStoreErrors.implicitKeyNotFound(tsToKeyCFName)
+    }
+
+    val keyByteArr = keySerializer.apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()
+    val bbuf = ByteBuffer.allocate(8)
+    bbuf.order(ByteOrder.BIG_ENDIAN)
+    bbuf.putLong(expiryTimestampMs)
+    val keyRow = secIndexKeyEncoder(InternalRow(bbuf.array(), keyByteArr))
+    keyRow
+  }
+
+  /**
+   * Function to check if the timer for the given key and timestamp is already registered
+   * @param expiryTimestampMs - expiry timestamp of the timer
+   * @return - true if the timer is already registered, false otherwise
+   */
+  def exists(expiryTimestampMs: Long): Boolean = {
+    getImpl(expiryTimestampMs) != null
+  }
+
+  private def getImpl(expiryTimestampMs: Long): UnsafeRow = {
+    store.get(encodeKey(expiryTimestampMs), keyToTsCFName)
+  }
+
+  /**
+   * Function to add a new timer for the given key and timestamp
+   * @param expiryTimestampMs - expiry timestamp of the timer
+   * @param newState = boolean value to be stored for the state value
+   */
+  def add(expiryTimestampMs: Long, newState: S): Unit = {
+    store.put(encodeKey(expiryTimestampMs), EMPTY_ROW, keyToTsCFName)
+    store.put(encodeSecIndexKey(expiryTimestampMs), EMPTY_ROW, tsToKeyCFName)
+  }
+
+  /**
+   * Function to remove the timer for the given key and timestamp
+   * @param expiryTimestampMs - expiry timestamp of the timer
+   */
+  def remove(expiryTimestampMs: Long): Unit = {
+    store.remove(encodeKey(expiryTimestampMs), keyToTsCFName)
+    store.remove(encodeSecIndexKey(expiryTimestampMs), tsToKeyCFName)
+  }
+
+  def listTimers(): Iterator[Long] = {
+    val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
+    if (!keyOption.isDefined) {
+      throw StateStoreErrors.implicitKeyNotFound(keyToTsCFName)
+    }
+
+    val keyByteArr = keySerializer.apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()
+    val keyRow = prefixKeyEncoder(InternalRow(keyByteArr))
+    val iter = store.prefixScan(keyRow, keyToTsCFName)
+    iter.map { kv =>
+      val keyRow = kv.key
+      keyRow.getLong(1)
+    }
+  }
+
+  private def getTimerRow(keyRow: UnsafeRow): (Any, Long) = {
+    // Decode the key object from the UnsafeRow
+    val keyBytes = keyRow.getBinary(1)
+    val retUnsafeRow = new UnsafeRow(1)
+    retUnsafeRow.pointTo(keyBytes, keyBytes.length)
+    val keyObj = keyExprEnc.resolveAndBind().
+    createDeserializer().apply(retUnsafeRow).asInstanceOf[Any]

Review Comment:
   nit: shall we move the `.` to this line and indent it? As written, it's slightly confusing because it doesn't read as a continuation of the previous line.
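   Separately, the secondary-index comment in the `TimerStateImpl.scala` hunk above says expiry timestamps are stored BIG_ENDIAN so that RocksDB's byte-wise comparator yields timestamp order. The sketch below checks that claim with an unsigned lexicographic byte comparison (what RocksDB's default bytewise comparator performs); it is an illustration, not Spark code. With little-endian bytes the order would break (256 would sort before 1). One caveat the sketch also surfaces: with two's-complement encoding, negative timestamps sort after all non-negative ones, so the ordering holds only for non-negative values.

```scala
import java.nio.{ByteBuffer, ByteOrder}

object BigEndianOrdering {
  // Encode a timestamp the same way encodeSecIndexKey does: 8 bytes, BIG_ENDIAN.
  def encode(ts: Long): Array[Byte] = {
    val buf = ByteBuffer.allocate(8)
    buf.order(ByteOrder.BIG_ENDIAN)
    buf.putLong(ts)
    buf.array()
  }

  // Unsigned lexicographic comparison, as a byte-wise (memcmp-style) comparator does it.
  def compareBytes(a: Array[Byte], b: Array[Byte]): Int = {
    var i = 0
    while (i < a.length && i < b.length) {
      val cmp = (a(i) & 0xff) - (b(i) & 0xff)
      if (cmp != 0) return cmp
      i += 1
    }
    a.length - b.length
  }
}
```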



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala:
##########
@@ -121,6 +123,46 @@ class StatefulProcessorHandleImpl(
 
   override def getQueryInfo(): QueryInfo = currQueryInfo
 
+  private def getTimerState[T](): TimerStateImpl[T] = {
+    new TimerStateImpl[T](store, timeoutMode, keyEncoder)
+  }
+
+  private val timerState = getTimerState[Boolean]()
+
+  override def registerTimer(expiryTimestampMs: Long): Unit = {
+    if (!(timeoutMode == ProcessingTime || timeoutMode == EventTime)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidTimeoutMode(timeoutMode.toString)
+    }
+
+    if (!(currState == INITIALIZED || currState == DATA_PROCESSED)) {
+      throw StateStoreErrors.cannotUseTimersWithInvalidHandleState(currState.toString)
+    }
+
+    if (timerState.exists(expiryTimestampMs)) {
+      logWarning(s"Timer already exists for expiryTimestampMs=$expiryTimestampMs")

Review Comment:
   Do we ever log the grouping key as context? I prefer to put the most relevant information on the same line if it isn't too long, but I'm fine with this if the surrounding log context already includes it. Without the grouping key, users wouldn't be able to understand what happened.
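   A minimal sketch of putting the key on the same line. In the PR the key appears to be reachable through `ImplicitGroupingKeyTracker.getImplicitKeyOption` (as used in `TimerStateImpl.encodeKey`); the helper below is hypothetical and only builds the message that would be passed to `logWarning`. One design consideration: a grouping key's `toString` may be long or contain sensitive data, which may argue for truncating or redacting it.

```scala
object TimerLogSketch {
  // Hypothetical helper; the real code would pass the result to logWarning/logInfo.
  def timerAlreadyExists(groupingKey: Option[Any], expiryTimestampMs: Long): String = {
    // Fall back to a placeholder if no implicit key is set.
    val keyStr = groupingKey.map(_.toString).getOrElse("<unknown>")
    s"Timer already exists for groupingKey=$keyStr, expiryTimestampMs=$expiryTimestampMs"
  }
}
```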



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -103,8 +116,12 @@ case class TransformWithStateExec(
     val keyObj = getKeyObj(keyRow)  // convert key to objects
     ImplicitGroupingKeyTracker.setImplicitKey(keyObj)
     val valueObjIter = valueRowIter.map(getValueObj.apply)
-    val mappedIterator = statefulProcessor.handleInputRows(keyObj, valueObjIter,
-      new TimerValuesImpl(batchTimestampMs, eventTimeWatermarkForLateEvents)).map { obj =>
+    val mappedIterator = statefulProcessor.handleInputRows(
+      keyObj,
+      valueObjIter,
+      new TimerValuesImpl(batchTimestampMs, eventTimeWatermarkForLateEvents),
+      new ExpiredTimerInfoImpl(false)
+      ).map { obj =>

Review Comment:
   nit: shift left, or move it to the line above?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -69,8 +70,20 @@ case class TransformWithStateExec(
 
   override def shortName: String = "transformWithStateExec"
 
-  // TODO: update this to run no-data batches when timer support is added
-  override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = false
+  override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
+    timeoutMode match {
+      // TODO: check if we can return true only if actual timers are registered

Review Comment:
   Good to call this out. Currently, flatMapGroupsWithState keeps running batches indefinitely when a processing-time timeout is used, and I've seen some complaints about it.
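   One hypothetical way to address the TODO is to run a no-data batch only when some registered timer would actually fire. The sketch assumes the operator can peek at the earliest registered expiry (the timestamp-first index would provide it as its first entry) and uses `nowMs` as a stand-in for the next batch's processing time; all names are illustrative, not the operator's real logic.

```scala
import scala.collection.immutable.TreeSet

// Hypothetical model of the timeout modes; names are assumptions.
sealed trait TimeoutMode
case object NoTimeouts extends TimeoutMode
case object ProcessingTime extends TimeoutMode
case object EventTime extends TimeoutMode

object NoDataBatchSketch {
  // Run another batch only if the earliest registered timer is due.
  def shouldRunAnotherBatch(
      mode: TimeoutMode,
      registeredTimers: TreeSet[Long],
      nowMs: Long,
      newInputWatermarkMs: Long): Boolean = mode match {
    case NoTimeouts => false
    case ProcessingTime => registeredTimers.headOption.exists(_ <= nowMs)
    case EventTime => registeredTimers.headOption.exists(_ <= newInputWatermarkMs)
  }
}
```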



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -103,8 +116,12 @@ case class TransformWithStateExec(
     val keyObj = getKeyObj(keyRow)  // convert key to objects
     ImplicitGroupingKeyTracker.setImplicitKey(keyObj)
     val valueObjIter = valueRowIter.map(getValueObj.apply)
-    val mappedIterator = statefulProcessor.handleInputRows(keyObj, valueObjIter,
-      new TimerValuesImpl(batchTimestampMs, eventTimeWatermarkForLateEvents)).map { obj =>
+    val mappedIterator = statefulProcessor.handleInputRows(
+      keyObj,
+      valueObjIter,
+      new TimerValuesImpl(batchTimestampMs, eventTimeWatermarkForLateEvents),

Review Comment:
   I agree it's super confusing, but conceptually, the two watermark values are really a single watermark value with two phases, as follows:
   
   1. the operator starts with watermark as eventTimeWatermarkForLateEvents (continuation of watermark in previous microbatch)
   2. the operator processes all inputs
   3. the operator advances the watermark as eventTimeWatermarkForEviction
   4. the operator handles eviction based on watermark advancement
    
   If it's easier to think in terms of a record-at-a-time streaming engine, picture the input queue as row1, row2, row3, ..., rowN, WT (watermark marker).
   
   That said, conceptually the watermark value here is correct, but in onTimer the watermark value should be eventTimeWatermarkForEviction. However, if we are concerned that users would find this confusing (even though it is correct) and we want to provide a single value, eventTimeWatermarkForEviction is arguably more appropriate. We already drop late events, so users never get the chance to process them anyway; the focus should therefore be on eviction.
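   The four steps above can be sketched as a record-at-a-time model of one micro-batch. This is a simplification under stated assumptions, not operator code: one timer per key, a row counts as late when its event time is at or below the watermark, and a timer fires when its expiry is at or below the watermark. `lateEventsWm` and `evictionWm` stand in for `eventTimeWatermarkForLateEvents` and `eventTimeWatermarkForEviction`.

```scala
object WatermarkPhasesSketch {
  final case class Event(key: String, eventTimeMs: Long)

  // Returns (rows kept for processing, keys whose timers fire) for one micro-batch.
  def runBatch(
      inputs: Seq[Event],
      timers: Map[String, Long], // key -> expiry timestamp (one timer per key, for brevity)
      lateEventsWm: Long,
      evictionWm: Long): (Seq[Event], Seq[String]) = {
    // Steps 1-2: process inputs against the previous batch's watermark; late rows are dropped.
    val processed = inputs.filter(_.eventTimeMs > lateEventsWm)
    // Steps 3-4: the watermark advances to evictionWm, and timers at or below it fire.
    val fired = timers.collect { case (k, ts) if ts <= evictionWm => k }.toSeq.sorted
    (processed, fired)
  }
}
```

   For example, with `lateEventsWm = 100` and `evictionWm = 200`, a row at event time 90 is dropped as late, while a timer registered at 120 fires in this batch only because the watermark advanced to 200 before timers were handled. That is why, from the point of view of onTimer, the eviction watermark is the relevant value.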



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala:
##########
@@ -0,0 +1,224 @@
+package org.apache.spark.sql.execution.streaming
+
+import java.io.Serializable
+import java.nio.{ByteBuffer, ByteOrder}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.streaming.TimeoutMode
+import org.apache.spark.sql.types._
+import org.apache.spark.util.NextIterator
+
+/**
+ * Singleton utils class used primarily while interacting with TimerState
+ */
+object TimerStateUtils {
+  case class TimestampWithKey(
+      key: Any,
+      expiryTimestampMs: Long) extends Serializable
+
+  val PROC_TIMERS_STATE_NAME = "_procTimers"
+  val EVENT_TIMERS_STATE_NAME = "_eventTimers"
+  val KEY_TO_TIMESTAMP_CF = "_keyToTimestamp"
+  val TIMESTAMP_TO_KEY_CF = "_timestampToKey"
+}
+
+/**
+ * Class that provides the implementation for storing timers
+ * used within the `transformWithState` operator.
+ * @param store - state store to be used for storing timer data
+ * @param timeoutMode - mode of timeout (event time or processing time)
+ * @param keyExprEnc - encoder for key expression
+ * @tparam S - type of timer value
+ */
+class TimerStateImpl[S](
+    store: StateStore,
+    timeoutMode: TimeoutMode,
+    keyExprEnc: ExpressionEncoder[Any]) extends Logging {
+
+  private val EMPTY_ROW =
+    UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
+
+  private val schemaForPrefixKey: StructType = new StructType()
+    .add("key", BinaryType)
+
+  private val schemaForKeyRow: StructType = new StructType()
+    .add("key", BinaryType)
+    .add("expiryTimestampMs", LongType, nullable = false)
+
+  private val keySchemaForSecIndex: StructType = new StructType()
+    .add("expiryTimestampMs", BinaryType, nullable = false)
+    .add("key", BinaryType)
+
+  private val schemaForValueRow: StructType =
+    StructType(Array(StructField("__dummy__", NullType)))
+
+  private val keySerializer = keyExprEnc.createSerializer()
+
+  private val prefixKeyEncoder = UnsafeProjection.create(schemaForPrefixKey)
+
+  private val keyEncoder = UnsafeProjection.create(schemaForKeyRow)
+
+  private val secIndexKeyEncoder = UnsafeProjection.create(keySchemaForSecIndex)
+
+  val timerCfName = if (timeoutMode == TimeoutMode.ProcessingTime) {
+    TimerStateUtils.PROC_TIMERS_STATE_NAME
+  } else {
+    TimerStateUtils.EVENT_TIMERS_STATE_NAME
+  }
+
+  val keyToTsCFName = timerCfName + TimerStateUtils.KEY_TO_TIMESTAMP_CF
+  store.createColFamilyIfAbsent(keyToTsCFName,
+    schemaForKeyRow, numColsPrefixKey = 1,
+    schemaForValueRow, useMultipleValuesPerKey = false,
+    isInternal = true)
+
+  val tsToKeyCFName = timerCfName + TimerStateUtils.TIMESTAMP_TO_KEY_CF
+  store.createColFamilyIfAbsent(tsToKeyCFName,
+    keySchemaForSecIndex, numColsPrefixKey = 0,
+    schemaForValueRow, useMultipleValuesPerKey = false,
+    isInternal = true)
+
+  private def encodeKey(expiryTimestampMs: Long): UnsafeRow = {
+    val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
+    if (!keyOption.isDefined) {
+      throw StateStoreErrors.implicitKeyNotFound(keyToTsCFName)
+    }
+
+    val keyByteArr = keySerializer.apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()
+    val keyRow = keyEncoder(InternalRow(keyByteArr, expiryTimestampMs))
+    keyRow
+  }
+
+  //  We maintain a secondary index that inverts the ordering of the timestamp
+  //  and grouping key and maintains the list of (expiry) timestamps in sorted order
+  //  (using BIG_ENDIAN encoding) within RocksDB.
+  //  This is because RocksDB uses byte-wise comparison using the default comparator to
+  //  determine sorted order of keys. This is used to read expired timers at any given
+  //  processing time/event time timestamp threshold by performing a range scan.
+  private def encodeSecIndexKey(expiryTimestampMs: Long): UnsafeRow = {
+    val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
+    if (!keyOption.isDefined) {
+      throw StateStoreErrors.implicitKeyNotFound(tsToKeyCFName)
+    }
+
+    val keyByteArr = keySerializer.apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()
+    val bbuf = ByteBuffer.allocate(8)
+    bbuf.order(ByteOrder.BIG_ENDIAN)
+    bbuf.putLong(expiryTimestampMs)
+    val keyRow = secIndexKeyEncoder(InternalRow(bbuf.array(), keyByteArr))
+    keyRow
+  }
+
+  /**
+   * Function to check if the timer for the given key and timestamp is already registered
+   * @param expiryTimestampMs - expiry timestamp of the timer
+   * @return - true if the timer is already registered, false otherwise
+   */
+  def exists(expiryTimestampMs: Long): Boolean = {
+    getImpl(expiryTimestampMs) != null
+  }
+
+  private def getImpl(expiryTimestampMs: Long): UnsafeRow = {
+    store.get(encodeKey(expiryTimestampMs), keyToTsCFName)
+  }
+
+  /**
+   * Function to add a new timer for the given key and timestamp
+   * @param expiryTimestampMs - expiry timestamp of the timer
+   * @param newState - boolean value to be stored for the state value
+   */
+  def add(expiryTimestampMs: Long, newState: S): Unit = {
+    store.put(encodeKey(expiryTimestampMs), EMPTY_ROW, keyToTsCFName)
+    store.put(encodeSecIndexKey(expiryTimestampMs), EMPTY_ROW, tsToKeyCFName)
+  }
+
+  /**
+   * Function to remove the timer for the given key and timestamp
+   * @param expiryTimestampMs - expiry timestamp of the timer
+   */
+  def remove(expiryTimestampMs: Long): Unit = {
+    store.remove(encodeKey(expiryTimestampMs), keyToTsCFName)
+    store.remove(encodeSecIndexKey(expiryTimestampMs), tsToKeyCFName)
+  }
+
+  def listTimers(): Iterator[Long] = {
+    val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
+    if (!keyOption.isDefined) {
+      throw StateStoreErrors.implicitKeyNotFound(keyToTsCFName)
+    }
+
+    val keyByteArr = keySerializer.apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()
+    val keyRow = prefixKeyEncoder(InternalRow(keyByteArr))
+    val iter = store.prefixScan(keyRow, keyToTsCFName)
+    iter.map { kv =>
+      val keyRow = kv.key
+      keyRow.getLong(1)
+    }
+  }
+
+  private def getTimerRow(keyRow: UnsafeRow): (Any, Long) = {

Review Comment:
   nit: could the name be more specific, making it clear that it decodes a row from tsToKey? Later on, someone could pass a keyRow from keyToTs by mistake and break things (a crash at runtime).
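The layout being decoded here is the secondary-index key: an 8-byte big-endian expiry timestamp followed by the serialized grouping key. The sketch below is a hypothetical, Spark-free illustration that uses plain byte arrays instead of `UnsafeRow`; `TsToKeyCodec`, `encodeTsToKey`, `decodeTsToKeyRow`, and `compareBytes` are made-up names. It shows a decoder whose name states which column family's layout it expects, and why big-endian encoding lets a byte-wise comparator (which the code comment above attributes to RocksDB's default) return timers in expiry order. That ordering only holds for non-negative timestamps: a negative `Long` has its sign bit set, so it sorts after every non-negative value under unsigned byte comparison.

```scala
import java.nio.{ByteBuffer, ByteOrder}

object TsToKeyCodec {
  // Encode a tsToKey entry: 8-byte big-endian expiry timestamp, then the grouping-key bytes.
  def encodeTsToKey(expiryTimestampMs: Long, keyBytes: Array[Byte]): Array[Byte] = {
    val buf = ByteBuffer.allocate(8 + keyBytes.length).order(ByteOrder.BIG_ENDIAN)
    buf.putLong(expiryTimestampMs)
    buf.put(keyBytes)
    buf.array()
  }

  // Decoder named after the only layout it understands; passing a keyToTs entry is a bug.
  def decodeTsToKeyRow(bytes: Array[Byte]): (Array[Byte], Long) = {
    require(bytes.length >= 8, s"not a tsToKey entry: only ${bytes.length} bytes")
    val buf = ByteBuffer.wrap(bytes).order(ByteOrder.BIG_ENDIAN)
    val expiryTimestampMs = buf.getLong()
    val keyBytes = new Array[Byte](bytes.length - 8)
    buf.get(keyBytes)
    (keyBytes, expiryTimestampMs)
  }

  // Unsigned lexicographic byte comparison, mimicking a byte-wise key comparator.
  def compareBytes(a: Array[Byte], b: Array[Byte]): Int = {
    val n = math.min(a.length, b.length)
    var i = 0
    while (i < n) {
      val c = (a(i) & 0xff) - (b(i) & 0xff)
      if (c != 0) return c
      i += 1
    }
    a.length - b.length
  }

  def main(args: Array[String]): Unit = {
    val key = "user-1".getBytes("UTF-8")
    val timestamps = Seq(70000L, 256L, 1000L, 5L)
    val sorted = timestamps.map(encodeTsToKey(_, key)).sortWith(compareBytes(_, _) < 0)
    // Byte-wise order of the encoded entries matches the numeric order of the timestamps.
    println(sorted.map(decodeTsToKeyRow(_)._2)) // List(5, 256, 1000, 70000)
  }
}
```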



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -163,6 +249,16 @@ case class TransformWithStateExec(
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver
 
+    timeoutMode match {
+      case ProcessingTime =>
+        require(batchTimestampMs.nonEmpty)
+
+      case EventTime =>
+        require(eventTimeWatermarkForEviction.nonEmpty)

Review Comment:
   ditto



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala:
##########
@@ -0,0 +1,224 @@
+  private def encodeKey(expiryTimestampMs: Long): UnsafeRow = {
+    val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption

Review Comment:
   nit: shall we extract "get the implicit key as a byte array" into a private method? It appears three times, differing only in the informative parameter passed to the error.
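A minimal sketch of the suggested refactoring, written in plain Scala so it runs on its own. `ImplicitKeyTrackerStub`, `TimerKeyHelperSketch`, `implicitKeyNotFound`, and `getImplicitKeyBytes` are hypothetical stand-ins for `ImplicitGroupingKeyTracker`, `TimerStateImpl`, `StateStoreErrors.implicitKeyNotFound`, and the repeated serialize-to-bytes logic. Because only the state name passed to the error differs between call sites, it becomes the helper's single parameter.

```scala
object ImplicitKeyTrackerStub {
  // Stand-in for ImplicitGroupingKeyTracker: a thread-local grouping key.
  private val current = new ThreadLocal[Option[Any]] {
    override def initialValue(): Option[Any] = None
  }
  def getImplicitKeyOption: Option[Any] = current.get()
  def setImplicitKey(key: Any): Unit = current.set(Some(key))
  def removeImplicitKey(): Unit = current.set(None)
}

// keyToBytes stands in for keySerializer(...).asInstanceOf[UnsafeRow].getBytes().
final class TimerKeyHelperSketch(keyToBytes: Any => Array[Byte]) {
  val keyToTsCFName = "_procTimers_keyToTimestamp"
  val tsToKeyCFName = "_procTimers_timestampToKey"

  // Hypothetical error factory; the real code throws StateStoreErrors.implicitKeyNotFound.
  private def implicitKeyNotFound(stateName: String): Throwable =
    new IllegalStateException(s"Implicit key not found for state $stateName")

  // The extracted helper: resolve the implicit key and serialize it, or fail with the
  // name of the column family the caller was working with.
  private def getImplicitKeyBytes(stateName: String): Array[Byte] =
    ImplicitKeyTrackerStub.getImplicitKeyOption
      .map(keyToBytes)
      .getOrElse(throw implicitKeyNotFound(stateName))

  // The three former call sites now differ only in the state name they pass.
  def encodeKeyBytes: Array[Byte] = getImplicitKeyBytes(keyToTsCFName)
  def encodeSecIndexKeyBytes: Array[Byte] = getImplicitKeyBytes(tsToKeyCFName)
  def listTimersPrefixBytes: Array[Byte] = getImplicitKeyBytes(keyToTsCFName)
}
```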



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -163,6 +249,16 @@ case class TransformWithStateExec(
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver
 
+    timeoutMode match {
+      case ProcessingTime =>
+        require(batchTimestampMs.nonEmpty)

Review Comment:
   nit: should this be an internal error? I'm OK with using `require` and expecting the surrounding code to convert the exception, but it would probably be better to provide a message as well.
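For reference, Scala's `Predef.require(requirement, message)` throws an `IllegalArgumentException` whose message is `requirement failed: ` followed by the supplied message, so adding a message costs one extra argument per check. A minimal sketch: `TimeoutModeSketch`, its case objects, and `validateBatchInputs` are hypothetical stand-ins for the operator's timeout mode and the checks in `doExecute()`, and the message wording is only illustrative.

```scala
sealed trait TimeoutModeSketch
case object ProcessingTimeMode extends TimeoutModeSketch
case object EventTimeMode extends TimeoutModeSketch
case object NoTimeoutMode extends TimeoutModeSketch

object RequireWithMessageSketch {
  // Each check names the missing input and the mode that needs it.
  def validateBatchInputs(
      mode: TimeoutModeSketch,
      batchTimestampMs: Option[Long],
      eventTimeWatermarkForEviction: Option[Long]): Unit = mode match {
    case ProcessingTimeMode =>
      require(batchTimestampMs.nonEmpty,
        "batchTimestampMs must be set when the timeout mode is ProcessingTime")
    case EventTimeMode =>
      require(eventTimeWatermarkForEviction.nonEmpty,
        "eventTimeWatermarkForEviction must be set when the timeout mode is EventTime")
    case NoTimeoutMode => // nothing to check
  }
}
```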



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -535,6 +535,22 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared
     }
   }
 
+  testWithColumnFamilies(s"RocksDB: column family creation with invalid names",

Review Comment:
   Shall we add a test for a non-internal request that uses a CF name starting with `_`?
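A self-contained sketch of the rule such a test would pin down. It assumes, based on the `isInternal = true` flag used for the `_procTimers`/`_eventTimers` column families above, that names beginning with `_` are reserved for internal column families. `ColumnFamilyNameValidator`, `validate`, and the error text are hypothetical; the real test would exercise the RocksDB state store's column-family creation path instead.

```scala
object ColumnFamilyNameValidator {
  // Hypothetical rule: user-facing (non-internal) column families may not use the
  // reserved "_" prefix; internal ones (e.g. the timer CFs) may.
  def validate(cfName: String, isInternal: Boolean): Either[String, String] =
    if (cfName == null || cfName.trim.isEmpty) {
      Left("column family name must be non-empty")
    } else if (!isInternal && cfName.startsWith("_")) {
      Left(s"column family name '$cfName' uses the reserved '_' prefix")
    } else {
      Right(cfName)
    }
}
```

A test along these lines would assert that `validate("_myState", isInternal = false)` is rejected while `validate("_procTimers_keyToTimestamp", isInternal = true)` is accepted.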



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
For additional commands, e-mail: reviews-help@spark.apache.org