Posted to commits@spark.apache.org by td...@apache.org on 2017/10/05 02:25:25 UTC

spark git commit: [SPARK-22187][SS] Update unsaferow format for saved state such that we can set timeouts when state is null

Repository: spark
Updated Branches:
  refs/heads/master bb035f1ee -> 969ffd631


[SPARK-22187][SS] Update unsaferow format for saved state such that we can set timeouts when state is null

## What changes were proposed in this pull request?

Currently, the group state of a user-defined type is encoded as top-level columns in the UnsafeRows stored in the state store. When needed, the timeout timestamp is also saved as the last top-level column. Since the group state is serialized to top-level columns, you cannot save "null" as the value of the state (setting null in all the top-level columns is not equivalent). As a result, the user is not allowed to set a timeout without first initializing the state for a key, which in practice has proven confusing.
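
For illustration only, here is a minimal sketch of the pre-change flat layout, assuming a hypothetical state type with a single `count` field (the field name and column types are assumptions, not taken from the patch):

```scala
import org.apache.spark.sql.types._

// Pre-change layout (hypothetical state type with one "count" field):
// the encoder's fields become top-level columns, and the timeout timestamp
// is appended as the last top-level column when timeouts are enabled.
val flatStateSchema = new StructType()
  .add("count", LongType, nullable = false)            // state fields, flattened
  .add("timeoutTimestamp", LongType, nullable = false) // appended when needed

// With this layout there is no column that can represent "no state at all":
// nulling every flattened field is not the same as the state being absent.
println(flatStateSchema.treeString)
```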

This PR changes the row format so that the state is saved as a nested struct column. This allows the state to be set to null and avoids these confusing corner cases.
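
For comparison, a rough sketch of the post-change layout, mirroring `stateSchema` in the new `FlatMapGroupsWithState_StateManager` below (the `count` field is again a hypothetical example state type, not from the patch):

```scala
import org.apache.spark.sql.types._

// New layout: the entire user state is a single nullable struct column,
// with the timeout timestamp stored alongside it when timeouts are enabled.
val userStateSchema = new StructType().add("count", LongType) // hypothetical state type
val nestedStateSchema = new StructType()
  .add("groupState", userStateSchema, nullable = true)
  .add("timeoutTimestamp", LongType)

// Absent state is now representable by nulling the single struct column
// (cf. getStateRow, which calls setNullAt on the nested state ordinal),
// while the timeout column still carries a valid timestamp.
println(nestedStateSchema.treeString)
```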

## How was this patch tested?
Refactored tests.

Author: Tathagata Das <ta...@gmail.com>

Closes #19416 from tdas/SPARK-22187.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/969ffd63
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/969ffd63
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/969ffd63

Branch: refs/heads/master
Commit: 969ffd631746125eb2b83722baf6f6e7ddd2092c
Parents: bb035f1
Author: Tathagata Das <ta...@gmail.com>
Authored: Wed Oct 4 19:25:22 2017 -0700
Committer: Tathagata Das <ta...@gmail.com>
Committed: Wed Oct 4 19:25:22 2017 -0700

----------------------------------------------------------------------
 .../streaming/FlatMapGroupsWithStateExec.scala  | 133 +++-------------
 .../FlatMapGroupsWithState_StateManager.scala   | 153 +++++++++++++++++++
 .../streaming/FlatMapGroupsWithStateSuite.scala | 130 ++++++++--------
 3 files changed, 246 insertions(+), 170 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/969ffd63/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index ab690fd..aab06d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -23,10 +23,8 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
-import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.util.CompletionIterator
 
 /**
@@ -62,26 +60,7 @@ case class FlatMapGroupsWithStateExec(
   import GroupStateImpl._
 
   private val isTimeoutEnabled = timeoutConf != NoTimeout
-  private val timestampTimeoutAttribute =
-    AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)()
-  private val stateAttributes: Seq[Attribute] = {
-    val encSchemaAttribs = stateEncoder.schema.toAttributes
-    if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs
-  }
-  // Get the serializer for the state, taking into account whether we need to save timestamps
-  private val stateSerializer = {
-    val encoderSerializer = stateEncoder.namedExpressions
-    if (isTimeoutEnabled) {
-      encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP)
-    } else {
-      encoderSerializer
-    }
-  }
-  // Get the deserializer for the state. Note that this must be done in the driver, as
-  // resolving and binding of deserializer expressions to the encoded type can be safely done
-  // only in the driver.
-  private val stateDeserializer = stateEncoder.resolveAndBind().deserializer
-
+  val stateManager = new FlatMapGroupsWithState_StateManager(stateEncoder, isTimeoutEnabled)
 
   /** Distribute by grouping attributes */
   override def requiredChildDistribution: Seq[Distribution] =
@@ -109,11 +88,11 @@ case class FlatMapGroupsWithStateExec(
     child.execute().mapPartitionsWithStateStore[InternalRow](
       getStateInfo,
       groupingAttributes.toStructType,
-      stateAttributes.toStructType,
+      stateManager.stateSchema,
       indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
-        val updater = new StateStoreUpdater(store)
+        val processor = new InputProcessor(store)
 
         // If timeout is based on event time, then filter late data based on watermark
         val filteredIter = watermarkPredicateForData match {
@@ -128,7 +107,7 @@ case class FlatMapGroupsWithStateExec(
         // all the data has been processed. This is to ensure that the timeout information of all
         // the keys with data is updated before they are processed for timeouts.
         val outputIterator =
-          updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys()
+          processor.processNewData(filteredIter) ++ processor.processTimedOutState()
 
         // Return an iterator of all the rows generated by all the keys, such that when fully
         // consumed, all the state updates will be committed by the state store
@@ -143,7 +122,7 @@ case class FlatMapGroupsWithStateExec(
   }
 
   /** Helper class to update the state store */
-  class StateStoreUpdater(store: StateStore) {
+  class InputProcessor(store: StateStore) {
 
     // Converters for translating input keys, values, output data between rows and Java objects
     private val getKeyObj =
@@ -152,14 +131,6 @@ case class FlatMapGroupsWithStateExec(
       ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
     private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
 
-    // Converters for translating state between rows and Java objects
-    private val getStateObjFromRow = ObjectOperator.deserializeRowToObject(
-      stateDeserializer, stateAttributes)
-    private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer)
-
-    // Index of the additional metadata fields in the state row
-    private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute)
-
     // Metrics
     private val numUpdatedStateRows = longMetric("numUpdatedStateRows")
     private val numOutputRows = longMetric("numOutputRows")
@@ -168,20 +139,19 @@ case class FlatMapGroupsWithStateExec(
      * For every group, get the key, values and corresponding state and call the function,
      * and return an iterator of rows
      */
-    def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
+    def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
       val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output)
       groupedIter.flatMap { case (keyRow, valueRowIter) =>
         val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
         callFunctionAndUpdateState(
-          keyUnsafeRow,
+          stateManager.getState(store, keyUnsafeRow),
           valueRowIter,
-          store.get(keyUnsafeRow),
           hasTimedOut = false)
       }
     }
 
     /** Find the groups that have timeout set and are timing out right now, and call the function */
-    def updateStateForTimedOutKeys(): Iterator[InternalRow] = {
+    def processTimedOutState(): Iterator[InternalRow] = {
       if (isTimeoutEnabled) {
         val timeoutThreshold = timeoutConf match {
           case ProcessingTimeTimeout => batchTimestampMs.get
@@ -190,12 +160,11 @@ case class FlatMapGroupsWithStateExec(
             throw new IllegalStateException(
               s"Cannot filter timed out keys for $timeoutConf")
         }
-        val timingOutKeys = store.getRange(None, None).filter { rowPair =>
-          val timeoutTimestamp = getTimeoutTimestamp(rowPair.value)
-          timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold
+        val timingOutKeys = stateManager.getAllState(store).filter { state =>
+          state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
         }
-        timingOutKeys.flatMap { rowPair =>
-          callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true)
+        timingOutKeys.flatMap { stateData =>
+          callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true)
         }
       } else Iterator.empty
     }
@@ -205,72 +174,43 @@ case class FlatMapGroupsWithStateExec(
      * iterator. Note that the store updating is lazy, that is, the store will be updated only
      * after the returned iterator is fully consumed.
      *
-     * @param keyRow Row representing the key, cannot be null
+     * @param stateData All the data related to the state to be updated
      * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty
-     * @param prevStateRow Row representing the previous state, can be null
      * @param hasTimedOut Whether this function is being called for a key timeout
      */
     private def callFunctionAndUpdateState(
-        keyRow: UnsafeRow,
+        stateData: FlatMapGroupsWithState_StateData,
         valueRowIter: Iterator[InternalRow],
-        prevStateRow: UnsafeRow,
         hasTimedOut: Boolean): Iterator[InternalRow] = {
 
-      val keyObj = getKeyObj(keyRow)  // convert key to objects
+      val keyObj = getKeyObj(stateData.keyRow)  // convert key to objects
       val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
-      val stateObj = getStateObj(prevStateRow)
-      val keyedState = GroupStateImpl.createForStreaming(
-        Option(stateObj),
+      val groupState = GroupStateImpl.createForStreaming(
+        Option(stateData.stateObj),
         batchTimestampMs.getOrElse(NO_TIMESTAMP),
         eventTimeWatermark.getOrElse(NO_TIMESTAMP),
         timeoutConf,
         hasTimedOut)
 
       // Call function, get the returned objects and convert them to rows
-      val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj =>
+      val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj =>
         numOutputRows += 1
         getOutputRow(obj)
       }
 
       // When the iterator is consumed, then write changes to state
       def onIteratorCompletion: Unit = {
-
-        val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp
-        // If the state has not yet been set but timeout has been set, then
-        // we have to generate a row to save the timeout. However, attempting serialize
-        // null using case class encoder throws -
-        //    java.lang.NullPointerException: Null value appeared in non-nullable field:
-        //    If the schema is inferred from a Scala tuple / case class, or a Java bean, please
-        //    try to use scala.Option[_] or other nullable types.
-        if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) {
-          throw new IllegalStateException(
-            "Cannot set timeout when state is not defined, that is, state has not been" +
-              "initialized or has been removed")
-        }
-
-        if (keyedState.hasRemoved) {
-          store.remove(keyRow)
+        if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) {
+          stateManager.removeState(store, stateData.keyRow)
           numUpdatedStateRows += 1
-
         } else {
-          val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow)
-          val stateRowToWrite = if (keyedState.hasUpdated) {
-            getStateRow(keyedState.get)
-          } else {
-            prevStateRow
-          }
-
-          val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp
-          val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged
+          val currentTimeoutTimestamp = groupState.getTimeoutTimestamp
+          val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp
+          val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged
 
           if (shouldWriteState) {
-            if (stateRowToWrite == null) {
-              // This should never happen because checks in GroupStateImpl should avoid cases
-              // where empty state would need to be written
-              throw new IllegalStateException("Attempting to write empty state")
-            }
-            setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp)
-            store.put(keyRow, stateRowToWrite)
+            val updatedStateObj = if (groupState.exists) groupState.get else null
+            stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp)
             numUpdatedStateRows += 1
           }
         }
@@ -279,28 +219,5 @@ case class FlatMapGroupsWithStateExec(
       // Return an iterator of rows such that fully consumed, the updated state value will be saved
       CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion)
     }
-
-    /** Returns the state as Java object if defined */
-    def getStateObj(stateRow: UnsafeRow): Any = {
-      if (stateRow != null) getStateObjFromRow(stateRow) else null
-    }
-
-    /** Returns the row for an updated state */
-    def getStateRow(obj: Any): UnsafeRow = {
-      assert(obj != null)
-      getStateRowFromObj(obj)
-    }
-
-    /** Returns the timeout timestamp of a state row is set */
-    def getTimeoutTimestamp(stateRow: UnsafeRow): Long = {
-      if (isTimeoutEnabled && stateRow != null) {
-        stateRow.getLong(timeoutTimestampIndex)
-      } else NO_TIMESTAMP
-    }
-
-    /** Set the timestamp in a state row */
-    def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = {
-      if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps)
-    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/969ffd63/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala
new file mode 100644
index 0000000..d077836
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, CaseWhen, CreateNamedStruct, GetStructField, IsNull, Literal, UnsafeRow}
+import org.apache.spark.sql.execution.ObjectOperator
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
+import org.apache.spark.sql.types.{IntegerType, LongType, StructType}
+
+
+/**
+ * Class to serialize/write/read/deserialize state for
+ * [[org.apache.spark.sql.execution.streaming.FlatMapGroupsWithStateExec]].
+ */
+class FlatMapGroupsWithState_StateManager(
+    stateEncoder: ExpressionEncoder[Any],
+    shouldStoreTimestamp: Boolean) extends Serializable {
+
+  /** Schema of the state rows saved in the state store */
+  val stateSchema = {
+    val schema = new StructType().add("groupState", stateEncoder.schema, nullable = true)
+    if (shouldStoreTimestamp) schema.add("timeoutTimestamp", LongType) else schema
+  }
+
+  /** Get deserialized state and corresponding timeout timestamp for a key */
+  def getState(store: StateStore, keyRow: UnsafeRow): FlatMapGroupsWithState_StateData = {
+    val stateRow = store.get(keyRow)
+    stateDataForGets.withNew(
+      keyRow, stateRow, getStateObj(stateRow), getTimestamp(stateRow))
+  }
+
+  /** Put state and timeout timestamp for a key */
+  def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timestamp: Long): Unit = {
+    val stateRow = getStateRow(state)
+    setTimestamp(stateRow, timestamp)
+    store.put(keyRow, stateRow)
+  }
+
+  /** Removes all information related to a key */
+  def removeState(store: StateStore, keyRow: UnsafeRow): Unit = {
+    store.remove(keyRow)
+  }
+
+  /** Get all the keys and corresponding state rows in the state store */
+  def getAllState(store: StateStore): Iterator[FlatMapGroupsWithState_StateData] = {
+    val stateDataForGetAllState = FlatMapGroupsWithState_StateData()
+    store.getRange(None, None).map { pair =>
+      stateDataForGetAllState.withNew(
+        pair.key, pair.value, getStateObjFromRow(pair.value), getTimestamp(pair.value))
+    }
+  }
+
+  // Ordinals of the information stored in the state row
+  private lazy val nestedStateOrdinal = 0
+  private lazy val timeoutTimestampOrdinal = 1
+
+  // Get the serializer for the state, taking into account whether we need to save timestamps
+  private val stateSerializer = {
+    val nestedStateExpr = CreateNamedStruct(
+      stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e)))
+    if (shouldStoreTimestamp) {
+      Seq(nestedStateExpr, Literal(GroupStateImpl.NO_TIMESTAMP))
+    } else {
+      Seq(nestedStateExpr)
+    }
+  }
+
+  // Get the deserializer for the state. Note that this must be done in the driver, as
+  // resolving and binding of deserializer expressions to the encoded type can be safely done
+  // only in the driver.
+  private val stateDeserializer = {
+    val boundRefToNestedState = BoundReference(nestedStateOrdinal, stateEncoder.schema, true)
+    val deser = stateEncoder.resolveAndBind().deserializer.transformUp {
+      case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal)
+    }
+    CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser).toCodegen()
+  }
+
+  // Converters for translating state between rows and Java objects
+  private lazy val getStateObjFromRow = ObjectOperator.deserializeRowToObject(
+    stateDeserializer, stateSchema.toAttributes)
+  private lazy val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer)
+
+  // Reusable instance for returning state information
+  private lazy val stateDataForGets = FlatMapGroupsWithState_StateData()
+
+  /** Returns the state as Java object if defined */
+  private def getStateObj(stateRow: UnsafeRow): Any = {
+    if (stateRow == null) null
+    else getStateObjFromRow(stateRow)
+  }
+
+  /** Returns the row for an updated state */
+  private def getStateRow(obj: Any): UnsafeRow = {
+    val row = getStateRowFromObj(obj)
+    if (obj == null) {
+      row.setNullAt(nestedStateOrdinal)
+    }
+    row
+  }
+
+  /** Returns the timeout timestamp of a state row, if set */
+  private def getTimestamp(stateRow: UnsafeRow): Long = {
+    if (shouldStoreTimestamp && stateRow != null) {
+      stateRow.getLong(timeoutTimestampOrdinal)
+    } else NO_TIMESTAMP
+  }
+
+  /** Set the timestamp in a state row */
+  private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = {
+    if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinal, timeoutTimestamps)
+  }
+}
+
+/**
+ * Class to capture deserialized state and timestamp returned by the state manager.
+ * This is intended for reuse.
+ */
+case class FlatMapGroupsWithState_StateData(
+    var keyRow: UnsafeRow = null,
+    var stateRow: UnsafeRow = null,
+    var stateObj: Any = null,
+    var timeoutTimestamp: Long = -1) {
+  def withNew(
+      newKeyRow: UnsafeRow,
+      newStateRow: UnsafeRow,
+      newStateObj: Any,
+      newTimeout: Long): this.type = {
+    keyRow = newKeyRow
+    stateRow = newStateRow
+    stateObj = newStateObj
+    timeoutTimestamp = newTimeout
+    this
+  }
+}
+

http://git-wip-us.apache.org/repos/asf/spark/blob/969ffd63/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 9d74a5c..d2e8beb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -289,13 +289,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
     }
   }
 
-  // Values used for testing StateStoreUpdater
+  // Values used for testing InputProcessor
   val currentBatchTimestamp = 1000
   val currentBatchWatermark = 1000
   val beforeTimeoutThreshold = 999
   val afterTimeoutThreshold = 1001
 
-  // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout
+  // Tests for InputProcessor.processNewData() when timeout = NoTimeout
   for (priorState <- Seq(None, Some(0))) {
     val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state"
     val testName = s"NoTimeout - $priorStateStr - "
@@ -322,7 +322,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
       expectedState = None)        // should be removed
   }
 
-  // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != NoTimeout
+  // Tests for InputProcessor.processTimedOutState() when timeout != NoTimeout
   for (priorState <- Seq(None, Some(0))) {
     for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) {
       var testName = ""
@@ -365,6 +365,18 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
           expectedState = None)                                 // state should be removed
       }
 
+      // Tests with ProcessingTimeTimeout
+      if (priorState == None) {
+        testStateUpdateWithData(
+          s"ProcessingTimeTimeout - $testName - timeout updated without initializing state",
+          stateUpdates = state => { state.setTimeoutDuration(5000) },
+          timeoutConf = ProcessingTimeTimeout,
+          priorState = None,
+          priorTimeoutTimestamp = priorTimeoutTimestamp,
+          expectedState = None,
+          expectedTimeoutTimestamp = currentBatchTimestamp + 5000)
+      }
+
       testStateUpdateWithData(
         s"ProcessingTimeTimeout - $testName - state and timeout duration updated",
         stateUpdates =
@@ -376,9 +388,35 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
         expectedTimeoutTimestamp = currentBatchTimestamp + 5000) // timestamp should change
 
       testStateUpdateWithData(
+        s"ProcessingTimeTimeout - $testName - timeout updated after state removed",
+        stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) },
+        timeoutConf = ProcessingTimeTimeout,
+        priorState = priorState,
+        priorTimeoutTimestamp = priorTimeoutTimestamp,
+        expectedState = None,
+        expectedTimeoutTimestamp = currentBatchTimestamp + 5000)
+
+      // Tests with EventTimeTimeout
+
+      if (priorState == None) {
+        testStateUpdateWithData(
+          s"EventTimeTimeout - $testName - setting timeout without init state not allowed",
+          stateUpdates = state => {
+            state.setTimeoutTimestamp(10000)
+          },
+          timeoutConf = EventTimeTimeout,
+          priorState = None,
+          priorTimeoutTimestamp = priorTimeoutTimestamp,
+          expectedState = None,
+          expectedTimeoutTimestamp = 10000)
+      }
+
+      testStateUpdateWithData(
         s"EventTimeTimeout - $testName - state and timeout timestamp updated",
         stateUpdates =
-          (state: GroupState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) },
+          (state: GroupState[Int]) => {
+            state.update(5); state.setTimeoutTimestamp(5000)
+          },
         timeoutConf = EventTimeTimeout,
         priorState = priorState,
         priorTimeoutTimestamp = priorTimeoutTimestamp,
@@ -397,50 +435,23 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
         timeoutConf = EventTimeTimeout,
         priorState = priorState,
         priorTimeoutTimestamp = priorTimeoutTimestamp,
-        expectedState = Some(5),                                 // state should change
-        expectedTimeoutTimestamp = NO_TIMESTAMP)                 // timestamp should not update
-    }
-  }
-
-  // Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData(),
-  // Try to remove these cases in the future
-  for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) {
-    val testName =
-      if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout"
-    testStateUpdateWithData(
-      s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed",
-      stateUpdates = state => { state.setTimeoutDuration(5000) },
-      timeoutConf = ProcessingTimeTimeout,
-      priorState = None,
-      priorTimeoutTimestamp = priorTimeoutTimestamp,
-      expectedException = classOf[IllegalStateException])
-
-    testStateUpdateWithData(
-      s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed",
-      stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) },
-      timeoutConf = ProcessingTimeTimeout,
-      priorState = Some(5),
-      priorTimeoutTimestamp = priorTimeoutTimestamp,
-      expectedException = classOf[IllegalStateException])
-
-    testStateUpdateWithData(
-      s"EventTimeTimeout - $testName - setting timeout without init state not allowed",
-      stateUpdates = state => { state.setTimeoutTimestamp(10000) },
-      timeoutConf = EventTimeTimeout,
-      priorState = None,
-      priorTimeoutTimestamp = priorTimeoutTimestamp,
-      expectedException = classOf[IllegalStateException])
+        expectedState = Some(5), // state should change
+        expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update
 
-    testStateUpdateWithData(
-      s"EventTimeTimeout - $testName - setting timeout with state removal not allowed",
-      stateUpdates = state => { state.remove(); state.setTimeoutTimestamp(10000) },
-      timeoutConf = EventTimeTimeout,
-      priorState = Some(5),
-      priorTimeoutTimestamp = priorTimeoutTimestamp,
-      expectedException = classOf[IllegalStateException])
+      testStateUpdateWithData(
+        s"EventTimeTimeout - $testName - setting timeout with state removal not allowed",
+        stateUpdates = state => {
+          state.remove(); state.setTimeoutTimestamp(10000)
+        },
+        timeoutConf = EventTimeTimeout,
+        priorState = priorState,
+        priorTimeoutTimestamp = priorTimeoutTimestamp,
+        expectedState = None,
+        expectedTimeoutTimestamp = 10000)
+    }
   }
 
-  // Tests for StateStoreUpdater.updateStateForTimedOutKeys()
+  // Tests for InputProcessor.processTimedOutState()
   val preTimeoutState = Some(5)
   for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) {
     testStateUpdateWithTimeout(
@@ -924,7 +935,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
     if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) {
       return // there can be no prior timestamp, when there is no prior state
     }
-    test(s"StateStoreUpdater - updates with data - $testName") {
+    test(s"InputProcessor - process new data - $testName") {
       val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => {
         assert(state.hasTimedOut === false, "hasTimedOut not false")
         assert(values.nonEmpty, "Some value is expected")
@@ -946,7 +957,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
       expectedState: Option[Int],
       expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = {
 
-    test(s"StateStoreUpdater - updates for timeout - $testName") {
+    test(s"InputProcessor - process timed out state - $testName") {
       val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => {
         assert(state.hasTimedOut === true, "hasTimedOut not true")
         assert(values.isEmpty, "values not empty")
@@ -973,21 +984,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
     val store = newStateStore()
     val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec(
       mapGroupsFunc, timeoutConf, currentBatchTimestamp)
-    val updater = new mapGroupsSparkPlan.StateStoreUpdater(store)
+    val inputProcessor = new mapGroupsSparkPlan.InputProcessor(store)
+    val stateManager = mapGroupsSparkPlan.stateManager
     val key = intToRow(0)
     // Prepare store with prior state configs
-    if (priorState.nonEmpty) {
-      val row = updater.getStateRow(priorState.get)
-      updater.setTimeoutTimestamp(row, priorTimeoutTimestamp)
-      store.put(key.copy(), row.copy())
+    if (priorState.nonEmpty || priorTimeoutTimestamp != NO_TIMESTAMP) {
+      stateManager.putState(store, key, priorState.orNull, priorTimeoutTimestamp)
     }
 
     // Call updating function to update state store
     def callFunction() = {
       val returnedIter = if (testTimeoutUpdates) {
-        updater.updateStateForTimedOutKeys()
+        inputProcessor.processTimedOutState()
       } else {
-        updater.updateStateForKeysWithData(Iterator(key))
+        inputProcessor.processNewData(Iterator(key))
       }
       returnedIter.size // consume the iterator to force state updates
     }
@@ -998,15 +1008,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
     } else {
       // Call function to update and verify updated state in store
       callFunction()
-      val updatedStateRow = store.get(key)
-      assert(
-        Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState,
+      val updatedState = stateManager.getState(store, key)
+      assert(Option(updatedState.stateObj).map(_.toString.toInt) === expectedState,
         "final state not as expected")
-      if (updatedStateRow != null) {
-        assert(
-          updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp,
-          "final timeout timestamp not as expected")
-      }
+      assert(updatedState.timeoutTimestamp === expectedTimeoutTimestamp,
+        "final timeout timestamp not as expected")
     }
   }
 

