You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ji...@apache.org on 2018/12/10 01:08:51 UTC

[flink] branch release-1.7 updated: [FLINK-10543][table] Leverage efficient timer deletion in relational operators

This is an automated email from the ASF dual-hosted git repository.

jincheng pushed a commit to branch release-1.7
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.7 by this push:
     new 0fa9ec0  [FLINK-10543][table] Leverage efficient timer deletion in relational operators
0fa9ec0 is described below

commit 0fa9ec030d6b102f4d24f0c7f8b58c0fab97fff6
Author: hequn8128 <ch...@gmail.com>
AuthorDate: Wed Oct 24 13:46:26 2018 +0800

    [FLINK-10543][table] Leverage efficient timer deletion in relational operators
    
    This closes #6918
---
 .../table/runtime/aggregate/CleanupState.scala     |  57 ++++
 ...ala => CoProcessFunctionWithCleanupState.scala} |  52 ++--
 .../aggregate/GroupAggProcessFunction.scala        |   4 +-
 .../KeyedProcessFunctionWithCleanupState.scala     |  38 +--
 .../aggregate/ProcTimeBoundedRangeOver.scala       |  29 +-
 .../aggregate/ProcTimeBoundedRowsOver.scala        |   4 +-
 .../runtime/aggregate/ProcTimeUnboundedOver.scala  |   4 +-
 .../ProcessFunctionWithCleanupState.scala          |  41 +--
 .../aggregate/RowTimeBoundedRangeOver.scala        |  20 +-
 .../runtime/aggregate/RowTimeBoundedRowsOver.scala |   8 +-
 .../runtime/aggregate/RowTimeUnboundedOver.scala   |   8 +-
 .../table/runtime/join/NonWindowFullJoin.scala     |   3 +-
 .../NonWindowFullJoinWithNonEquiPredicates.scala   |  23 +-
 .../table/runtime/join/NonWindowInnerJoin.scala    |   3 +-
 .../flink/table/runtime/join/NonWindowJoin.scala   | 104 ++-----
 .../runtime/join/NonWindowLeftRightJoin.scala      |   3 +-
 ...nWindowLeftRightJoinWithNonEquiPredicates.scala |  25 +-
 .../table/runtime/join/NonWindowOuterJoin.scala    |   9 +-
 .../NonWindowOuterJoinWithNonEquiPredicates.scala  |  47 +--
 .../triggers/StateCleaningCountTrigger.scala       |  13 +-
 .../table/runtime/harness/JoinHarnessTest.scala    | 320 +++++++++++----------
 .../StateCleaningCountTriggerHarnessTest.scala     |   7 +-
 .../KeyedProcessFunctionWithCleanupStateTest.scala |  10 +-
 .../ProcessFunctionWithCleanupStateTest.scala      |   4 +-
 24 files changed, 372 insertions(+), 464 deletions(-)

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/CleanupState.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/CleanupState.scala
new file mode 100644
index 0000000..d9c8e2c
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/CleanupState.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.aggregate
+
+import org.apache.flink.api.common.state.ValueState
+import org.apache.flink.streaming.api.functions.ProcessFunction
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction
+import java.lang.{Long => JLong}
+
+import org.apache.flink.streaming.api.TimerService
+
+/**
+  * Shared cleanup-timer logic mixed into [[ProcessFunction]] and [[CoProcessFunction]] subclasses.
+  */
+trait CleanupState {
+  /** (Re)registers the cleanup timer at currentTime + maxRetentionTime when required. */
+  def registerProcessingCleanupTimer(
+      cleanupTimeState: ValueState[JLong],
+      currentTime: Long,
+      minRetentionTime: Long,
+      maxRetentionTime: Long,
+      timerService: TimerService): Unit = {
+
+    // time of the last registered cleanup timer, or null if none has been registered
+    val curCleanupTime = cleanupTimeState.value()
+
+    // register a new timer if none exists, or if the existing one would fire before
+    // currentTime + minRetentionTime and thus delete state we still need to keep
+    if (curCleanupTime == null || (currentTime + minRetentionTime) > curCleanupTime) {
+      // the new (later) timer fires at currentTime + maxRetentionTime
+      val cleanupTime = currentTime + maxRetentionTime
+      // register the new timer; its time is remembered in cleanupTimeState below
+      timerService.registerProcessingTimeTimer(cleanupTime)
+      // delete the superseded timer so that only the latest cleanup timer remains
+      if (curCleanupTime != null) {
+        timerService.deleteProcessingTimeTimer(curCleanupTime)
+      }
+      cleanupTimeState.update(cleanupTime)
+    }
+  }
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/KeyedProcessFunctionWithCleanupState.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/CoProcessFunctionWithCleanupState.scala
similarity index 57%
copy from flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/KeyedProcessFunctionWithCleanupState.scala
copy to flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/CoProcessFunctionWithCleanupState.scala
index 4d6840a..0c76636 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/KeyedProcessFunctionWithCleanupState.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/CoProcessFunctionWithCleanupState.scala
@@ -15,17 +15,19 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.flink.table.runtime.aggregate
 
 import java.lang.{Long => JLong}
+
 import org.apache.flink.api.common.state.{State, ValueState, ValueStateDescriptor}
 import org.apache.flink.streaming.api.TimeDomain
-import org.apache.flink.streaming.api.functions.{KeyedProcessFunction, ProcessFunction}
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction
 import org.apache.flink.table.api.{StreamQueryConfig, Types}
 
-abstract class KeyedProcessFunctionWithCleanupState[K, I, O](queryConfig: StreamQueryConfig)
-  extends KeyedProcessFunction[K, I, O] {
+abstract class CoProcessFunctionWithCleanupState[IN1, IN2, OUT](queryConfig: StreamQueryConfig)
+  extends CoProcessFunction[IN1, IN2, OUT]
+  with CleanupState {
+
   protected val minRetentionTime: Long = queryConfig.getMinIdleStateRetentionTime
   protected val maxRetentionTime: Long = queryConfig.getMaxIdleStateRetentionTime
   protected val stateCleaningEnabled: Boolean = minRetentionTime > 1
@@ -35,29 +37,23 @@ abstract class KeyedProcessFunctionWithCleanupState[K, I, O](queryConfig: Stream
 
   protected def initCleanupTimeState(stateName: String) {
     if (stateCleaningEnabled) {
-      val inputCntDescriptor: ValueStateDescriptor[JLong] =
+      val cleanupTimeDescriptor: ValueStateDescriptor[JLong] =
         new ValueStateDescriptor[JLong](stateName, Types.LONG)
-      cleanupTimeState = getRuntimeContext.getState(inputCntDescriptor)
+      cleanupTimeState = getRuntimeContext.getState(cleanupTimeDescriptor)
     }
   }
 
-  protected def registerProcessingCleanupTimer(
-    ctx: KeyedProcessFunction[K, I, O]#Context,
+  protected def processCleanupTimer(
+    ctx: CoProcessFunction[IN1, IN2, OUT]#Context,
     currentTime: Long): Unit = {
     if (stateCleaningEnabled) {
-
-      // last registered timer
-      val curCleanupTime = cleanupTimeState.value()
-
-      // check if a cleanup timer is registered and
-      // that the current cleanup timer won't delete state we need to keep
-      if (curCleanupTime == null || (currentTime + minRetentionTime) > curCleanupTime) {
-        // we need to register a new (later) timer
-        val cleanupTime = currentTime + maxRetentionTime
-        // register timer and remember clean-up time
-        ctx.timerService().registerProcessingTimeTimer(cleanupTime)
-        cleanupTimeState.update(cleanupTime)
-      }
+      registerProcessingCleanupTimer(
+        cleanupTimeState,
+        currentTime,
+        minRetentionTime,
+        maxRetentionTime,
+        ctx.timerService()
+      )
     }
   }
 
@@ -65,21 +61,9 @@ abstract class KeyedProcessFunctionWithCleanupState[K, I, O](queryConfig: Stream
     ctx.timeDomain() == TimeDomain.PROCESSING_TIME
   }
 
-  protected def needToCleanupState(timestamp: Long): Boolean = {
-    if (stateCleaningEnabled) {
-      val cleanupTime = cleanupTimeState.value()
-      // check that the triggered timer is the last registered processing time timer.
-      null != cleanupTime && timestamp == cleanupTime
-    } else {
-      false
-    }
-  }
-
   protected def cleanupState(states: State*): Unit = {
     // clear all state
     states.foreach(_.clear())
-    if (stateCleaningEnabled) {
-      this.cleanupTimeState.clear()
-    }
+    this.cleanupTimeState.clear()
   }
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala
index c59efe2..2d72e6d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala
@@ -86,7 +86,7 @@ class GroupAggProcessFunction(
 
     val currentTime = ctx.timerService().currentProcessingTime()
     // register state-cleanup timer
-    registerProcessingCleanupTimer(ctx, currentTime)
+    processCleanupTimer(ctx, currentTime)
 
     val input = inputC.row
 
@@ -172,7 +172,7 @@ class GroupAggProcessFunction(
       ctx: ProcessFunction[CRow, CRow]#OnTimerContext,
       out: Collector[CRow]): Unit = {
 
-    if (needToCleanupState(timestamp)) {
+    if (stateCleaningEnabled) {
       cleanupState(state, cntState)
       function.cleanup()
     }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/KeyedProcessFunctionWithCleanupState.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/KeyedProcessFunctionWithCleanupState.scala
index 4d6840a..edf5c2c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/KeyedProcessFunctionWithCleanupState.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/KeyedProcessFunctionWithCleanupState.scala
@@ -25,13 +25,15 @@ import org.apache.flink.streaming.api.functions.{KeyedProcessFunction, ProcessFu
 import org.apache.flink.table.api.{StreamQueryConfig, Types}
 
 abstract class KeyedProcessFunctionWithCleanupState[K, I, O](queryConfig: StreamQueryConfig)
-  extends KeyedProcessFunction[K, I, O] {
+  extends KeyedProcessFunction[K, I, O]
+  with CleanupState {
+
   protected val minRetentionTime: Long = queryConfig.getMinIdleStateRetentionTime
   protected val maxRetentionTime: Long = queryConfig.getMaxIdleStateRetentionTime
   protected val stateCleaningEnabled: Boolean = minRetentionTime > 1
 
   // holds the latest registered cleanup timer
-  private var cleanupTimeState: ValueState[JLong] = _
+  protected var cleanupTimeState: ValueState[JLong] = _
 
   protected def initCleanupTimeState(stateName: String) {
     if (stateCleaningEnabled) {
@@ -41,23 +43,17 @@ abstract class KeyedProcessFunctionWithCleanupState[K, I, O](queryConfig: Stream
     }
   }
 
-  protected def registerProcessingCleanupTimer(
+  protected def processCleanupTimer(
     ctx: KeyedProcessFunction[K, I, O]#Context,
     currentTime: Long): Unit = {
     if (stateCleaningEnabled) {
-
-      // last registered timer
-      val curCleanupTime = cleanupTimeState.value()
-
-      // check if a cleanup timer is registered and
-      // that the current cleanup timer won't delete state we need to keep
-      if (curCleanupTime == null || (currentTime + minRetentionTime) > curCleanupTime) {
-        // we need to register a new (later) timer
-        val cleanupTime = currentTime + maxRetentionTime
-        // register timer and remember clean-up time
-        ctx.timerService().registerProcessingTimeTimer(cleanupTime)
-        cleanupTimeState.update(cleanupTime)
-      }
+      registerProcessingCleanupTimer(
+        cleanupTimeState,
+        currentTime,
+        minRetentionTime,
+        maxRetentionTime,
+        ctx.timerService()
+      )
     }
   }
 
@@ -65,16 +61,6 @@ abstract class KeyedProcessFunctionWithCleanupState[K, I, O](queryConfig: Stream
     ctx.timeDomain() == TimeDomain.PROCESSING_TIME
   }
 
-  protected def needToCleanupState(timestamp: Long): Boolean = {
-    if (stateCleaningEnabled) {
-      val cleanupTime = cleanupTimeState.value()
-      // check that the triggered timer is the last registered processing time timer.
-      null != cleanupTime && timestamp == cleanupTime
-    } else {
-      false
-    }
-  }
-
   protected def cleanupState(states: State*): Unit = {
     // clear all state
     states.foreach(_.clear())
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala
index 591b942..6126dc7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala
@@ -95,7 +95,7 @@ class ProcTimeBoundedRangeOver(
 
     val currentTime = ctx.timerService.currentProcessingTime
     // register state-cleanup timer
-    registerProcessingCleanupTimer(ctx, currentTime)
+    processCleanupTimer(ctx, currentTime)
 
     // buffer the event incoming event
 
@@ -117,11 +117,14 @@ class ProcTimeBoundedRangeOver(
     ctx: ProcessFunction[CRow, CRow]#OnTimerContext,
     out: Collector[CRow]): Unit = {
 
-    if (needToCleanupState(timestamp)) {
-      // clean up and return
-      cleanupState(rowMapState, accumulatorState)
-      function.cleanup()
-      return
+    if (stateCleaningEnabled) {
+      val cleanupTime = cleanupTimeState.value()
+      if (null != cleanupTime && timestamp == cleanupTime) {
+        // clean up and return
+        cleanupState(rowMapState, accumulatorState)
+        function.cleanup()
+        return
+      }
     }
 
     // remove timestamp set outside of ProcessFunction.
@@ -131,11 +134,10 @@ class ProcTimeBoundedRangeOver(
     // that have registered this time trigger 1 ms ago
 
     val currentTime = timestamp - 1
-    var i = 0
     // get the list of elements of current proctime
     val currentElements = rowMapState.get(currentTime)
 
-    // Expired clean-up timers pass the needToCleanupState() check.
+    // Timers other than the current clean-up timer fall through the check above.
     // Perform a null check to verify that we have data to process.
     if (null == currentElements) {
       return
@@ -156,7 +158,6 @@ class ProcTimeBoundedRangeOver(
     // and eliminate them. Multiple elements could have been received at the same timestamp
     // the removal of old elements happens only once per proctime as onTimer is called only once
     val iter = rowMapState.iterator
-    val markToRemove = new ArrayList[Long]()
     while (iter.hasNext) {
       val entry = iter.next()
       val elementKey = entry.getKey
@@ -169,17 +170,9 @@ class ProcTimeBoundedRangeOver(
           function.retract(accumulators, retractRow)
           iRemove += 1
         }
-        // mark element for later removal not to modify the iterator over MapState
-        markToRemove.add(elementKey)
+        iter.remove()
       }
     }
-    // need to remove in 2 steps not to have concurrent access errors via iterator to the MapState
-    i = 0
-    while (i < markToRemove.size()) {
-      rowMapState.remove(markToRemove.get(i))
-      i += 1
-    }
-
 
     // add current elements to aggregator. Multiple elements might
     // have arrived in the same proctime
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala
index ccddaa5..fa58ac5 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala
@@ -110,7 +110,7 @@ class ProcTimeBoundedRowsOver(
     val currentTime = ctx.timerService.currentProcessingTime
 
     // register state-cleanup timer
-    registerProcessingCleanupTimer(ctx, currentTime)
+    processCleanupTimer(ctx, currentTime)
 
     // initialize state for the processed element
     var accumulators = accumulatorState.value
@@ -187,7 +187,7 @@ class ProcTimeBoundedRowsOver(
     ctx: ProcessFunction[CRow, CRow]#OnTimerContext,
     out: Collector[CRow]): Unit = {
 
-    if (needToCleanupState(timestamp)) {
+    if (stateCleaningEnabled) {
       cleanupState(rowMapState, accumulatorState, counterState, smallestTsState)
       function.cleanup()
     }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedOver.scala
index 6e4c510..ce1a959 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedOver.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedOver.scala
@@ -71,7 +71,7 @@ class ProcTimeUnboundedOver(
     out: Collector[CRow]): Unit = {
 
     // register state-cleanup timer
-    registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
+    processCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
 
     val input = inputC.row
 
@@ -95,7 +95,7 @@ class ProcTimeUnboundedOver(
     ctx: ProcessFunction[CRow, CRow]#OnTimerContext,
     out: Collector[CRow]): Unit = {
 
-    if (needToCleanupState(timestamp)) {
+    if (stateCleaningEnabled) {
       cleanupState(state)
       function.cleanup()
     }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcessFunctionWithCleanupState.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcessFunctionWithCleanupState.scala
index 292fd3b..7263de7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcessFunctionWithCleanupState.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcessFunctionWithCleanupState.scala
@@ -26,40 +26,35 @@ import org.apache.flink.streaming.api.functions.ProcessFunction
 import org.apache.flink.table.api.{StreamQueryConfig, Types}
 
 abstract class ProcessFunctionWithCleanupState[IN,OUT](queryConfig: StreamQueryConfig)
-  extends ProcessFunction[IN, OUT]{
+  extends ProcessFunction[IN, OUT]
+  with CleanupState {
 
   protected val minRetentionTime: Long = queryConfig.getMinIdleStateRetentionTime
   protected val maxRetentionTime: Long = queryConfig.getMaxIdleStateRetentionTime
   protected val stateCleaningEnabled: Boolean = minRetentionTime > 1
 
   // holds the latest registered cleanup timer
-  private var cleanupTimeState: ValueState[JLong] = _
+  protected var cleanupTimeState: ValueState[JLong] = _
 
   protected def initCleanupTimeState(stateName: String) {
     if (stateCleaningEnabled) {
-      val inputCntDescriptor: ValueStateDescriptor[JLong] =
+      val cleanupTimeDescriptor: ValueStateDescriptor[JLong] =
         new ValueStateDescriptor[JLong](stateName, Types.LONG)
-      cleanupTimeState = getRuntimeContext.getState(inputCntDescriptor)
+      cleanupTimeState = getRuntimeContext.getState(cleanupTimeDescriptor)
     }
   }
 
-  protected def registerProcessingCleanupTimer(
+  protected def processCleanupTimer(
     ctx: ProcessFunction[IN, OUT]#Context,
     currentTime: Long): Unit = {
     if (stateCleaningEnabled) {
-
-      // last registered timer
-      val curCleanupTime = cleanupTimeState.value()
-
-      // check if a cleanup timer is registered and
-      // that the current cleanup timer won't delete state we need to keep
-      if (curCleanupTime == null || (currentTime + minRetentionTime) > curCleanupTime) {
-        // we need to register a new (later) timer
-        val cleanupTime = currentTime + maxRetentionTime
-        // register timer and remember clean-up time
-        ctx.timerService().registerProcessingTimeTimer(cleanupTime)
-        cleanupTimeState.update(cleanupTime)
-      }
+      registerProcessingCleanupTimer(
+        cleanupTimeState,
+        currentTime,
+        minRetentionTime,
+        maxRetentionTime,
+        ctx.timerService()
+      )
     }
   }
 
@@ -67,16 +62,6 @@ abstract class ProcessFunctionWithCleanupState[IN,OUT](queryConfig: StreamQueryC
     ctx.timeDomain() == TimeDomain.PROCESSING_TIME
   }
 
-  protected def needToCleanupState(timestamp: Long): Boolean = {
-    if (stateCleaningEnabled) {
-      val cleanupTime = cleanupTimeState.value()
-      // check that the triggered timer is the last registered processing time timer.
-      null != cleanupTime && timestamp == cleanupTime
-    } else {
-      false
-    }
-  }
-
   protected def cleanupState(states: State*): Unit = {
     // clear all state
     states.foreach(_.clear())
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala
index b13acdf..7c509d6 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala
@@ -114,7 +114,7 @@ class RowTimeBoundedRangeOver(
     val input = inputC.row
 
     // register state-cleanup timer
-    registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
+    processCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
 
     // triggering timestamp for trigger calculation
     val triggeringTs = input.getField(rowTimeIdx).asInstanceOf[Long]
@@ -143,7 +143,7 @@ class RowTimeBoundedRangeOver(
     out: Collector[CRow]): Unit = {
 
     if (isProcessingTimeTimer(ctx.asInstanceOf[OnTimerContext])) {
-      if (needToCleanupState(timestamp)) {
+      if (stateCleaningEnabled) {
 
         val keysIt = dataState.keys.iterator()
         val lastProcessedTime = lastTriggeringTsState.value
@@ -164,7 +164,7 @@ class RowTimeBoundedRangeOver(
           // There are records left to process because a watermark has not been received yet.
           // This would only happen if the input stream has stopped. So we don't need to clean up.
           // We leave the state as it is and schedule a new cleanup timer
-          registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
+          processCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
         }
       }
       return
@@ -188,9 +188,6 @@ class RowTimeBoundedRangeOver(
         aggregatesIndex = 0
       }
 
-      // keep up timestamps of retract data
-      val retractTsList: JList[Long] = new JArrayList[Long]
-
       // do retraction
       val iter = dataState.iterator()
       while (iter.hasNext) {
@@ -205,7 +202,7 @@ class RowTimeBoundedRangeOver(
             function.retract(accumulators, retractRow)
             dataListIndex += 1
           }
-          retractTsList.add(dataTs)
+          iter.remove()
         }
       }
 
@@ -230,20 +227,13 @@ class RowTimeBoundedRangeOver(
         dataListIndex += 1
       }
 
-      // remove the data that has been retracted
-      dataListIndex = 0
-      while (dataListIndex < retractTsList.size) {
-        dataState.remove(retractTsList.get(dataListIndex))
-        dataListIndex += 1
-      }
-
       // update state
       accumulatorState.update(accumulators)
     }
     lastTriggeringTsState.update(timestamp)
 
     // update cleanup timer
-    registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
+    processCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
   }
 
   override def close(): Unit = {
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala
index e120d6b..d01a499 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala
@@ -123,7 +123,7 @@ class RowTimeBoundedRowsOver(
     val input = inputC.row
 
     // register state-cleanup timer
-    registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
+    processCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
 
     // triggering timestamp for trigger calculation
     val triggeringTs = input.getField(rowTimeIdx).asInstanceOf[Long]
@@ -152,7 +152,7 @@ class RowTimeBoundedRowsOver(
     out: Collector[CRow]): Unit = {
 
     if (isProcessingTimeTimer(ctx.asInstanceOf[OnTimerContext])) {
-      if (needToCleanupState(timestamp)) {
+      if (stateCleaningEnabled) {
 
         val keysIt = dataState.keys.iterator()
         val lastProcessedTime = lastTriggeringTsState.value
@@ -173,7 +173,7 @@ class RowTimeBoundedRowsOver(
           // There are records left to process because a watermark has not been received yet.
           // This would only happen if the input stream has stopped. So we don't need to clean up.
           // We leave the state as it is and schedule a new cleanup timer
-          registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
+          processCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
         }
       }
       return
@@ -263,7 +263,7 @@ class RowTimeBoundedRowsOver(
     lastTriggeringTsState.update(timestamp)
 
     // update cleanup timer
-    registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
+    processCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
   }
 
   override def close(): Unit = {
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala
index 181c768..690d0d0 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala
@@ -108,7 +108,7 @@ abstract class RowTimeUnboundedOver(
     val input = inputC.row
 
     // register state-cleanup timer
-    registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
+    processCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
 
     val timestamp = input.getField(rowTimeIdx).asInstanceOf[Long]
     val curWatermark = ctx.timerService().currentWatermark()
@@ -143,7 +143,7 @@ abstract class RowTimeUnboundedOver(
       out: Collector[CRow]): Unit = {
 
     if (isProcessingTimeTimer(ctx.asInstanceOf[OnTimerContext])) {
-      if (needToCleanupState(timestamp)) {
+      if (stateCleaningEnabled) {
 
         // we check whether there are still records which have not been processed yet
         val noRecordsToProcess = !rowMapState.keys.iterator().hasNext
@@ -155,7 +155,7 @@ abstract class RowTimeUnboundedOver(
           // There are records left to process because a watermark has not been received yet.
           // This would only happen if the input stream has stopped. So we don't need to clean up.
           // We leave the state as it is and schedule a new cleanup timer
-          registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
+          processCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
         }
       }
       return
@@ -207,7 +207,7 @@ abstract class RowTimeUnboundedOver(
     }
 
     // update cleanup timer
-    registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
+    processCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
   }
 
   /**
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowFullJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowFullJoin.scala
index 57c60f1..5b1069f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowFullJoin.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowFullJoin.scala
@@ -66,13 +66,12 @@ class NonWindowFullJoin(
       value: CRow,
       ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
       out: Collector[CRow],
-      timerState: ValueState[Long],
       currentSideState: MapState[Row, JTuple2[Long, Long]],
       otherSideState: MapState[Row, JTuple2[Long, Long]],
       recordFromLeft: Boolean): Unit = {
 
     val inputRow = value.row
-    updateCurrentSide(value, ctx, timerState, currentSideState)
+    updateCurrentSide(value, ctx, currentSideState)
 
     cRowWrapper.reset()
     cRowWrapper.setCollector(out)
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowFullJoinWithNonEquiPredicates.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowFullJoinWithNonEquiPredicates.scala
index 9c27eb4..0166eef 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowFullJoinWithNonEquiPredicates.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowFullJoinWithNonEquiPredicates.scala
@@ -68,14 +68,13 @@ class NonWindowFullJoinWithNonEquiPredicates(
       value: CRow,
       ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
       out: Collector[CRow],
-      timerState: ValueState[Long],
       currentSideState: MapState[Row, JTuple2[Long, Long]],
       otherSideState: MapState[Row, JTuple2[Long, Long]],
       recordFromLeft: Boolean): Unit = {
 
     val currentJoinCntState = getJoinCntState(joinCntState, recordFromLeft)
     val inputRow = value.row
-    val cntAndExpiredTime = updateCurrentSide(value, ctx, timerState, currentSideState)
+    val cntAndExpiredTime = updateCurrentSide(value, ctx, currentSideState)
     if (!value.change && cntAndExpiredTime.f0 <= 0) {
       currentJoinCntState.remove(inputRow)
     }
@@ -99,18 +98,18 @@ class NonWindowFullJoinWithNonEquiPredicates(
   }
 
   /**
-    * Removes records which are expired from left state. Register a new timer if the state still
-    * holds records after the clean-up. Also, clear leftJoinCnt map state when clear left
-    * rowMapState.
+    * Called when a registered processing-time timer fires.
+    * Cleans up the left/right record state and both join-count states.
     */
-  override def expireOutTimeRow(
-      curTime: Long,
-      rowMapState: MapState[Row, JTuple2[Long, Long]],
-      timerState: ValueState[Long],
-      isLeft: Boolean,
-      ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext): Unit = {
+  override def onTimer(
+      timestamp: Long,
+      ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext,
+      out: Collector[CRow]): Unit = {
 
-    expireOutTimeRow(curTime, rowMapState, timerState, isLeft, joinCntState, ctx)
+    // outdated timers have already been deleted, so clean up the state directly.
+    if (stateCleaningEnabled) {
+      cleanupState(leftState, rightState, joinCntState(0), joinCntState(1))
+    }
   }
 }
 
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowInnerJoin.scala
index 2e5832c..91a7507 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowInnerJoin.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowInnerJoin.scala
@@ -63,13 +63,12 @@ class NonWindowInnerJoin(
       value: CRow,
       ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
       out: Collector[CRow],
-      timerState: ValueState[Long],
       currentSideState: MapState[Row, JTuple2[Long, Long]],
       otherSideState: MapState[Row, JTuple2[Long, Long]],
       isLeft: Boolean): Unit = {
 
     val inputRow = value.row
-    updateCurrentSide(value, ctx, timerState, currentSideState)
+    updateCurrentSide(value, ctx, currentSideState)
 
     cRowWrapper.setCollector(out)
     cRowWrapper.setChange(value.change)
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowJoin.scala
index c59f4c2..e15cbfa 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowJoin.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowJoin.scala
@@ -19,7 +19,7 @@ package org.apache.flink.table.runtime.join
 
 import org.apache.flink.api.common.functions.FlatJoinFunction
 import org.apache.flink.api.common.functions.util.FunctionUtils
-import org.apache.flink.api.common.state.{MapState, MapStateDescriptor, ValueState, ValueStateDescriptor}
+import org.apache.flink.api.common.state.{MapState, MapStateDescriptor}
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
 import org.apache.flink.api.java.typeutils.TupleTypeInfo
@@ -27,6 +27,7 @@ import org.apache.flink.configuration.Configuration
 import org.apache.flink.streaming.api.functions.co.CoProcessFunction
 import org.apache.flink.table.api.{StreamQueryConfig, Types}
 import org.apache.flink.table.codegen.Compiler
+import org.apache.flink.table.runtime.aggregate.CoProcessFunctionWithCleanupState
 import org.apache.flink.table.runtime.types.CRow
 import org.apache.flink.table.typeutils.TypeCheckUtils._
 import org.apache.flink.table.util.Logging
@@ -48,7 +49,7 @@ abstract class NonWindowJoin(
     genJoinFuncName: String,
     genJoinFuncCode: String,
     queryConfig: StreamQueryConfig)
-  extends CoProcessFunction[CRow, CRow, CRow]
+  extends CoProcessFunctionWithCleanupState[CRow, CRow, CRow](queryConfig)
   with Compiler[FlatJoinFunction[Row, Row, Row]]
   with Logging {
 
@@ -62,15 +63,6 @@ abstract class NonWindowJoin(
   protected var rightState: MapState[Row, JTuple2[Long, Long]] = _
   protected var cRowWrapper: CRowWrappingMultiOutputCollector = _
 
-  protected val minRetentionTime: Long = queryConfig.getMinIdleStateRetentionTime
-  protected val maxRetentionTime: Long = queryConfig.getMaxIdleStateRetentionTime
-  protected val stateCleaningEnabled: Boolean = minRetentionTime > 1
-
-  // state to record last timer of left stream, 0 means no timer
-  protected var leftTimer: ValueState[Long] = _
-  // state to record last timer of right stream, 0 means no timer
-  protected var rightTimer: ValueState[Long] = _
-
   // other condition function
   protected var joinFunction: FlatJoinFunction[Row, Row, Row] = _
 
@@ -78,7 +70,8 @@ abstract class NonWindowJoin(
   protected var curProcessTime: Long = _
 
   override def open(parameters: Configuration): Unit = {
-    LOG.debug(s"Compiling JoinFunction: $genJoinFuncName \n\n Code:\n$genJoinFuncCode")
+    LOG.debug(s"Compiling JoinFunction: $genJoinFuncName \n\n " +
+                s"Code:\n$genJoinFuncCode")
     val clazz = compile(
       getRuntimeContext.getUserCodeClassLoader,
       genJoinFuncName,
@@ -100,10 +93,7 @@ abstract class NonWindowJoin(
     rightState = getRuntimeContext.getMapState(rightStateDescriptor)
 
     // initialize timer state
-    val valueStateDescriptor1 = new ValueStateDescriptor[Long]("timervaluestate1", classOf[Long])
-    leftTimer = getRuntimeContext.getState(valueStateDescriptor1)
-    val valueStateDescriptor2 = new ValueStateDescriptor[Long]("timervaluestate2", classOf[Long])
-    rightTimer = getRuntimeContext.getState(valueStateDescriptor2)
+    initCleanupTimeState("NonWindowJoinCleanupTime")
 
     cRowWrapper = new CRowWrappingMultiOutputCollector()
     LOG.debug("Instantiating NonWindowJoin.")
@@ -122,7 +112,7 @@ abstract class NonWindowJoin(
       ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
       out: Collector[CRow]): Unit = {
 
-    processElement(valueC, ctx, out, leftTimer, leftState, rightState, isLeft = true)
+    processElement(valueC, ctx, out, leftState, rightState, isLeft = true)
   }
 
   /**
@@ -138,7 +128,7 @@ abstract class NonWindowJoin(
       ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
       out: Collector[CRow]): Unit = {
 
-    processElement(valueC, ctx, out, rightTimer, rightState, leftState, isLeft = false)
+    processElement(valueC, ctx, out, rightState, leftState, isLeft = false)
   }
 
   /**
@@ -154,28 +144,13 @@ abstract class NonWindowJoin(
       ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext,
       out: Collector[CRow]): Unit = {
 
-    if (stateCleaningEnabled && leftTimer.value == timestamp) {
-      expireOutTimeRow(
-        timestamp,
-        leftState,
-        leftTimer,
-        isLeft = true,
-        ctx
-      )
-    }
-
-    if (stateCleaningEnabled && rightTimer.value == timestamp) {
-      expireOutTimeRow(
-        timestamp,
-        rightState,
-        rightTimer,
-        isLeft = false,
-        ctx
-      )
+    // outdated timers have already been deleted, so clean up the state directly.
+    if (stateCleaningEnabled) {
+      cleanupState(leftState, rightState)
     }
   }
 
-  def getNewExpiredTime(curProcessTime: Long, oldExpiredTime: Long): Long = {
+  protected def getNewExpiredTime(curProcessTime: Long, oldExpiredTime: Long): Long = {
     if (stateCleaningEnabled && curProcessTime + minRetentionTime > oldExpiredTime) {
       curProcessTime + maxRetentionTime
     } else {
@@ -184,52 +159,14 @@ abstract class NonWindowJoin(
   }
 
   /**
-    * Removes records which are expired from the state. Register a new timer if the state still
-    * holds records after the clean-up.
-    */
-  def expireOutTimeRow(
-      curTime: Long,
-      rowMapState: MapState[Row, JTuple2[Long, Long]],
-      timerState: ValueState[Long],
-      isLeft: Boolean,
-      ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext): Unit = {
-
-    val rowMapIter = rowMapState.iterator()
-    var validTimestamp: Boolean = false
-
-    while (rowMapIter.hasNext) {
-      val mapEntry = rowMapIter.next()
-      val recordExpiredTime = mapEntry.getValue.f1
-      if (recordExpiredTime <= curTime) {
-        rowMapIter.remove()
-      } else {
-        // we found a timestamp that is still valid
-        validTimestamp = true
-      }
-    }
-
-    // If the state has non-expired timestamps, register a new timer.
-    // Otherwise clean the complete state for this input.
-    if (validTimestamp) {
-      val cleanupTime = curTime + maxRetentionTime
-      ctx.timerService.registerProcessingTimeTimer(cleanupTime)
-      timerState.update(cleanupTime)
-    } else {
-      timerState.clear()
-      rowMapState.clear()
-    }
-  }
-
-  /**
     * Puts or Retract an element from the input stream into state and search the other state to
     * output records meet the condition. Records will be expired in state if state retention time
     * has been specified.
     */
-  def processElement(
+  protected def processElement(
       value: CRow,
       ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
       out: Collector[CRow],
-      timerState: ValueState[Long],
       currentSideState: MapState[Row, JTuple2[Long, Long]],
       otherSideState: MapState[Row, JTuple2[Long, Long]],
       isLeft: Boolean): Unit
@@ -240,14 +177,12 @@ abstract class NonWindowJoin(
     *
     * @param value            The input CRow
     * @param ctx              The ctx to register timer or get current time
-    * @param timerState       The state to record last timer
     * @param currentSideState The state to hold current side stream element
     * @return The row number and expired time for current input row
     */
-  def updateCurrentSide(
+  protected def updateCurrentSide(
       value: CRow,
       ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
-      timerState: ValueState[Long],
       currentSideState: MapState[Row, JTuple2[Long, Long]]): JTuple2[Long, Long] = {
 
     val inputRow = value.row
@@ -261,10 +196,7 @@ abstract class NonWindowJoin(
 
     cntAndExpiredTime.f1 = getNewExpiredTime(curProcessTime, cntAndExpiredTime.f1)
     // update timer if necessary
-    if (stateCleaningEnabled && timerState.value() == 0) {
-      timerState.update(cntAndExpiredTime.f1)
-      ctx.timerService().registerProcessingTimeTimer(cntAndExpiredTime.f1)
-    }
+    processCleanupTimer(ctx, curProcessTime)
 
     // update current side stream state
     if (!value.change) {
@@ -282,7 +214,7 @@ abstract class NonWindowJoin(
     cntAndExpiredTime
   }
 
-  def callJoinFunction(
+  protected def callJoinFunction(
       inputRow: Row,
       inputRowFromLeft: Boolean,
       otherSideRow: Row,
@@ -294,8 +226,8 @@ abstract class NonWindowJoin(
       joinFunction.join(otherSideRow, inputRow, cRowWrapper)
     }
   }
 
   override def close(): Unit = {
     FunctionUtils.closeFunction(joinFunction)
   }
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowLeftRightJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowLeftRightJoin.scala
index b4f14e4..5995fb8 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowLeftRightJoin.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowLeftRightJoin.scala
@@ -69,13 +69,12 @@ class NonWindowLeftRightJoin(
       value: CRow,
       ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
       out: Collector[CRow],
-      timerState: ValueState[Long],
       currentSideState: MapState[Row, JTuple2[Long, Long]],
       otherSideState: MapState[Row, JTuple2[Long, Long]],
       recordFromLeft: Boolean): Unit = {
 
     val inputRow = value.row
-    updateCurrentSide(value, ctx, timerState, currentSideState)
+    updateCurrentSide(value, ctx, currentSideState)
 
     cRowWrapper.reset()
     cRowWrapper.setCollector(out)
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowLeftRightJoinWithNonEquiPredicates.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowLeftRightJoinWithNonEquiPredicates.scala
index 33517cc..a3e25f9 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowLeftRightJoinWithNonEquiPredicates.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowLeftRightJoinWithNonEquiPredicates.scala
@@ -71,14 +71,13 @@ class NonWindowLeftRightJoinWithNonEquiPredicates(
       value: CRow,
       ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
       out: Collector[CRow],
-      timerState: ValueState[Long],
       currentSideState: MapState[Row, JTuple2[Long, Long]],
       otherSideState: MapState[Row, JTuple2[Long, Long]],
       recordFromLeft: Boolean): Unit = {
 
     val currentJoinCntState = getJoinCntState(joinCntState, recordFromLeft)
     val inputRow = value.row
-    val cntAndExpiredTime = updateCurrentSide(value, ctx, timerState, currentSideState)
+    val cntAndExpiredTime = updateCurrentSide(value, ctx, currentSideState)
     if (!value.change && cntAndExpiredTime.f0 <= 0 && recordFromLeft == isLeftJoin) {
       currentJoinCntState.remove(inputRow)
     }
@@ -101,17 +100,21 @@ class NonWindowLeftRightJoinWithNonEquiPredicates(
   }
 
   /**
-    * Removes records which are expired from state. Register a new timer if the state still
-    * holds records after the clean-up. Also, clear joinCnt map state when clear rowMapState.
+    * Called when a registered processing-time timer fires.
+    * Cleans up the left/right record state and the join-count state of the preserved side.
     */
-  override def expireOutTimeRow(
-      curTime: Long,
-      rowMapState: MapState[Row, JTuple2[Long, Long]],
-      timerState: ValueState[Long],
-      isLeft: Boolean,
-      ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext): Unit = {
+  override def onTimer(
+      timestamp: Long,
+      ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext,
+      out: Collector[CRow]): Unit = {
 
-    expireOutTimeRow(curTime, rowMapState, timerState, isLeft, joinCntState, ctx)
+    // outdated timers have already been deleted, so clean up the state directly.
+    if (stateCleaningEnabled) {
+      cleanupState(
+        leftState,
+        rightState,
+        getJoinCntState(joinCntState, isLeftJoin))
+    }
   }
 }
 
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowOuterJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowOuterJoin.scala
index 0018a16..9fadbb0 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowOuterJoin.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowOuterJoin.scala
@@ -73,7 +73,7 @@ abstract class NonWindowOuterJoin(
     * @param otherSideState   the other side state
     * @return the number of matched rows
     */
-  def preservedJoin(
+  protected def preservedJoin(
       inputRow: Row,
       inputRowFromLeft: Boolean,
       otherSideState: MapState[Row, JTuple2[Long, Long]]): Long = {
@@ -106,7 +106,7 @@ abstract class NonWindowOuterJoin(
     * RowWrapper has been reset before we call retractJoin and we also assume that the current
     * change of cRowWrapper is equal to value.change.
     */
-  def retractJoin(
+  protected def retractJoin(
       value: CRow,
       inputRowFromLeft: Boolean,
       currentSideState: MapState[Row, JTuple2[Long, Long]],
@@ -152,7 +152,8 @@ abstract class NonWindowOuterJoin(
     * Return approximate number of records in corresponding state. Only check if record number is
     * 0, 1 or bigger.
     */
-  def approxiRecordNumInState(currentSideState: MapState[Row, JTuple2[Long, Long]]): Long = {
+  protected def approxiRecordNumInState(
+      currentSideState: MapState[Row, JTuple2[Long, Long]]): Long = {
     var recordNum = 0L
     val it = currentSideState.iterator()
     while(it.hasNext && recordNum < 2) {
@@ -164,7 +165,7 @@ abstract class NonWindowOuterJoin(
   /**
     * Append input row with default null value if there is no match and Collect.
     */
-  def collectAppendNull(
+  protected def collectAppendNull(
       inputRow: Row,
       inputFromLeft: Boolean,
       out: Collector[Row]): Unit = {
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowOuterJoinWithNonEquiPredicates.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowOuterJoinWithNonEquiPredicates.scala
index 8fe2f4f..0eb6a81 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowOuterJoinWithNonEquiPredicates.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowOuterJoinWithNonEquiPredicates.scala
@@ -21,7 +21,6 @@ import org.apache.flink.api.common.state._
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
 import org.apache.flink.configuration.Configuration
-import org.apache.flink.streaming.api.functions.co.CoProcessFunction
 import org.apache.flink.table.api.{StreamQueryConfig, Types}
 import org.apache.flink.table.runtime.types.CRow
 import org.apache.flink.types.Row
@@ -83,7 +82,7 @@ import org.apache.flink.types.Row
     * unmatched or vice versa. The RowWrapper has been reset before we call retractJoin and we
     * also assume that the current change of cRowWrapper is equal to value.change.
     */
-  def retractJoinWithNonEquiPreds(
+  protected def retractJoinWithNonEquiPreds(
       value: CRow,
       inputRowFromLeft: Boolean,
       otherSideState: MapState[Row, JTuple2[Long, Long]],
@@ -132,48 +131,6 @@ import org.apache.flink.types.Row
   }
 
   /**
-    * Removes records which are expired from state. Registers a new timer if the state still
-    * holds records after the clean-up. Also, clear joinCnt map state when clear rowMapState.
-    */
-  def expireOutTimeRow(
-      curTime: Long,
-      rowMapState: MapState[Row, JTuple2[Long, Long]],
-      timerState: ValueState[Long],
-      isLeft: Boolean,
-      joinCntState: Array[MapState[Row, Long]],
-      ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext): Unit = {
-
-    val currentJoinCntState = getJoinCntState(joinCntState, isLeft)
-    val rowMapIter = rowMapState.iterator()
-    var validTimestamp: Boolean = false
-
-    while (rowMapIter.hasNext) {
-      val mapEntry = rowMapIter.next()
-      val recordExpiredTime = mapEntry.getValue.f1
-      if (recordExpiredTime <= curTime) {
-        rowMapIter.remove()
-        currentJoinCntState.remove(mapEntry.getKey)
-      } else {
-        // we found a timestamp that is still valid
-        validTimestamp = true
-      }
-    }
-    // If the state has non-expired timestamps, register a new timer.
-    // Otherwise clean the complete state for this input.
-    if (validTimestamp) {
-      val cleanupTime = curTime + maxRetentionTime
-      ctx.timerService.registerProcessingTimeTimer(cleanupTime)
-      timerState.update(cleanupTime)
-    } else {
-      timerState.clear()
-      rowMapState.clear()
-      if (isLeft == isLeftJoin) {
-        currentJoinCntState.clear()
-      }
-    }
-  }
-
-  /**
     * Get left or right join cnt state.
     *
     * @param joinCntState    the join cnt state array, index 0 is left join cnt state, index 1
@@ -181,7 +138,7 @@ import org.apache.flink.types.Row
     * @param isLeftCntState the flag whether get the left join cnt state
     * @return the corresponding join cnt state
     */
-  def getJoinCntState(
+  protected def getJoinCntState(
       joinCntState: Array[MapState[Row, Long]],
       isLeftCntState: Boolean)
     : MapState[Row, Long] = {
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/triggers/StateCleaningCountTrigger.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/triggers/StateCleaningCountTrigger.scala
index 3c18449..6ae5e63 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/triggers/StateCleaningCountTrigger.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/triggers/StateCleaningCountTrigger.scala
@@ -71,6 +71,10 @@ class StateCleaningCountTrigger(queryConfig: StreamQueryConfig, maxCount: Long)
         val cleanupTime = currentTime + maxRetentionTime
         // register timer and remember clean-up time
         ctx.registerProcessingTimeTimer(cleanupTime)
+        // delete the previously registered, now outdated, cleanup timer
+        if (curCleanupTime != null) {
+          ctx.deleteProcessingTimeTimer(curCleanupTime)
+        }
 
         ctx.getPartitionedState(cleanupStateDesc).update(cleanupTime)
       }
@@ -93,12 +97,9 @@ class StateCleaningCountTrigger(queryConfig: StreamQueryConfig, maxCount: Long)
       ctx: TriggerContext): TriggerResult = {
 
     if (stateCleaningEnabled) {
-      val cleanupTime = ctx.getPartitionedState(cleanupStateDesc).value()
-      // check that the triggered timer is the last registered processing time timer.
-      if (null != cleanupTime && time == cleanupTime) {
-        clear(window, ctx)
-        return TriggerResult.FIRE_AND_PURGE
-      }
+      // clear directly: outdated cleanup timers are deleted when a new one is registered
+      clear(window, ctx)
+      return TriggerResult.FIRE_AND_PURGE
     }
     TriggerResult.CONTINUE
   }
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala
index bd19be8..4619c75 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala
@@ -21,21 +21,18 @@ import java.lang.{Integer => JInt, Long => JLong}
 import java.util.concurrent.ConcurrentLinkedQueue
 
 import org.apache.flink.api.common.time.Time
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
-import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo
 import org.apache.flink.api.java.operators.join.JoinType
-import org.apache.flink.api.java.typeutils.RowTypeInfo
 import org.apache.flink.streaming.api.operators.co.KeyedCoProcessOperator
 import org.apache.flink.streaming.api.watermark.Watermark
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
 import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness
-import org.apache.flink.table.api.{StreamQueryConfig, Types}
-import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, RowResultSortComparatorWithWatermarks, TestStreamQueryConfig, TupleRowKeySelector}
+import org.apache.flink.table.api.Types
+import org.apache.flink.table.runtime.harness.HarnessTestBase.{TestStreamQueryConfig, TupleRowKeySelector}
 import org.apache.flink.table.runtime.join._
 import org.apache.flink.table.runtime.operators.KeyedCoProcessOperatorWithWatermarkDelay
-import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
-import org.apache.flink.types.Row
-import org.junit.Assert.{assertEquals, assertTrue}
+import org.apache.flink.table.runtime.types.CRow
+import org.junit.Assert.assertEquals
 import org.junit.Test
 
 /**
@@ -830,14 +827,6 @@ class JoinHarnessTest extends HarnessTestBase {
   @Test
   def testNonWindowInnerJoin() {
 
-    val joinReturnType = CRowTypeInfo(new RowTypeInfo(
-      Array[TypeInformation[_]](
-        INT_TYPE_INFO,
-        STRING_TYPE_INFO,
-        INT_TYPE_INFO,
-        STRING_TYPE_INFO),
-      Array("a", "b", "c", "d")))
-
     val joinProcessFunc = new NonWindowInnerJoin(
       rowType,
       rowType,
@@ -879,35 +868,32 @@ class JoinHarnessTest extends HarnessTestBase {
     // right stream input and output normally
     testHarness.processElement2(new StreamRecord(
       CRow(1: JInt, "Hi1")))
-    assertEquals(6, testHarness.numKeyedStateEntries())
-    assertEquals(3, testHarness.numProcessingTimeTimers())
+    // lkeys(1, 2) rkeys(1) timer_key_time(1:5, 2:6)
+    assertEquals(5, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
     testHarness.setProcessingTime(4)
     testHarness.processElement2(new StreamRecord(
       CRow(2: JInt, "Hello1")))
-    assertEquals(8, testHarness.numKeyedStateEntries())
-    assertEquals(4, testHarness.numProcessingTimeTimers())
+    // lkeys(1, 2) rkeys(1, 2) timer_key_time(1:5, 2:6)
+    assertEquals(6, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
 
-    // expired left stream record with key value of 1
+    // expired stream record with key value of 1
     testHarness.setProcessingTime(5)
     testHarness.processElement2(new StreamRecord(
       CRow(1: JInt, "Hi2")))
-    assertEquals(6, testHarness.numKeyedStateEntries())
-    assertEquals(3, testHarness.numProcessingTimeTimers())
-
-    // expired all left stream record
-    testHarness.setProcessingTime(6)
-    assertEquals(4, testHarness.numKeyedStateEntries())
+    // lkeys(2) rkeys(1, 2) timer_key_time(1:9, 2:6)
+    assertEquals(5, testHarness.numKeyedStateEntries())
     assertEquals(2, testHarness.numProcessingTimeTimers())
 
-    // expired right stream record with key value of 2
-    testHarness.setProcessingTime(8)
+    // expired all left stream records
+    testHarness.setProcessingTime(6)
+    // lkeys() rkeys(1) timer_key_time(1:9)
     assertEquals(2, testHarness.numKeyedStateEntries())
     assertEquals(1, testHarness.numProcessingTimeTimers())
 
-    testHarness.setProcessingTime(10)
-    assertTrue(testHarness.numKeyedStateEntries() > 0)
-    // expired all right stream record
-    testHarness.setProcessingTime(11)
+    // expired all stream records
+    testHarness.setProcessingTime(9)
     assertEquals(0, testHarness.numKeyedStateEntries())
     assertEquals(0, testHarness.numProcessingTimeTimers())
 
@@ -975,32 +961,37 @@ class JoinHarnessTest extends HarnessTestBase {
       CRow(1: JInt, "Hi1")))
     testHarness.processElement2(new StreamRecord(
       CRow(false, 1: JInt, "Hi1")))
-    assertEquals(5, testHarness.numKeyedStateEntries())
-    assertEquals(3, testHarness.numProcessingTimeTimers())
+    // lkeys(1, 2) rkeys() timer_key_time(1:5, 2:6)
+    assertEquals(4, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
     testHarness.setProcessingTime(4)
     testHarness.processElement2(new StreamRecord(
       CRow(2: JInt, "Hello1")))
-    assertEquals(7, testHarness.numKeyedStateEntries())
-    assertEquals(4, testHarness.numProcessingTimeTimers())
+    // lkeys(1, 2) rkeys(2) timer_key_time(1:5, 2:6)
+    assertEquals(5, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
 
     testHarness.processElement1(new StreamRecord(
       CRow(false, 1: JInt, "aaa")))
-    // expired left stream record with key value of 1
+    // expired stream records with key value of 1
     testHarness.setProcessingTime(5)
+    // lkeys(2) rkeys(2) timer_key_time(2:6)
     testHarness.processElement2(new StreamRecord(
       CRow(1: JInt, "Hi2")))
     testHarness.processElement2(new StreamRecord(
       CRow(false, 1: JInt, "Hi2")))
-    assertEquals(5, testHarness.numKeyedStateEntries())
-    assertEquals(3, testHarness.numProcessingTimeTimers())
+    // lkeys(2) rkeys(2) timer_key_time(1:9, 2:6)
+    assertEquals(4, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
 
-    // expired all left stream record
+    // expired all stream records
     testHarness.setProcessingTime(6)
-    assertEquals(3, testHarness.numKeyedStateEntries())
-    assertEquals(2, testHarness.numProcessingTimeTimers())
+    // lkeys() rkeys() timer_key_time(1:9)
+    assertEquals(1, testHarness.numKeyedStateEntries())
+    assertEquals(1, testHarness.numProcessingTimeTimers())
 
-    // expired right stream record with key value of 2
-    testHarness.setProcessingTime(8)
+    // expired all data
+    testHarness.setProcessingTime(9)
     assertEquals(0, testHarness.numKeyedStateEntries())
     assertEquals(0, testHarness.numProcessingTimeTimers())
 
@@ -1067,32 +1058,36 @@ class JoinHarnessTest extends HarnessTestBase {
       CRow(1: JInt, "Hi1")))
     testHarness.processElement2(new StreamRecord(
       CRow(false, 1: JInt, "Hi1")))
-    assertEquals(5, testHarness.numKeyedStateEntries())
-    assertEquals(3, testHarness.numProcessingTimeTimers())
+    // lkeys(1, 2) rkeys() timer_key_time(1:5, 2:6)
+    assertEquals(4, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
     testHarness.setProcessingTime(4)
     testHarness.processElement2(new StreamRecord(
       CRow(2: JInt, "Hello1")))
-    assertEquals(7, testHarness.numKeyedStateEntries())
-    assertEquals(4, testHarness.numProcessingTimeTimers())
+    // lkeys(1, 2) rkeys(2) timer_key_time(1:5, 2:6)
+    assertEquals(5, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
 
     testHarness.processElement1(new StreamRecord(
       CRow(false, 1: JInt, "aaa")))
-    // expired left stream record with key value of 1
+    // expired stream records with key value of 1
     testHarness.setProcessingTime(5)
     testHarness.processElement2(new StreamRecord(
       CRow(1: JInt, "Hi2")))
     testHarness.processElement2(new StreamRecord(
       CRow(false, 1: JInt, "Hi2")))
-    assertEquals(5, testHarness.numKeyedStateEntries())
-    assertEquals(3, testHarness.numProcessingTimeTimers())
+    // lkeys(2) rkeys(2) timer_key_time(1:9, 2:6)
+    assertEquals(4, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
 
-    // expired all left stream record
+    // expired stream records with key value of 2
     testHarness.setProcessingTime(6)
-    assertEquals(3, testHarness.numKeyedStateEntries())
-    assertEquals(2, testHarness.numProcessingTimeTimers())
+    // lkeys() rkeys() timer_key_time(1:9)
+    assertEquals(1, testHarness.numKeyedStateEntries())
+    assertEquals(1, testHarness.numProcessingTimeTimers())
 
-    // expired right stream record with key value of 2
-    testHarness.setProcessingTime(8)
+    // expired all data
+    testHarness.setProcessingTime(9)
     assertEquals(0, testHarness.numKeyedStateEntries())
     assertEquals(0, testHarness.numProcessingTimeTimers())
 
@@ -1160,7 +1155,7 @@ class JoinHarnessTest extends HarnessTestBase {
     testHarness.processElement1(new StreamRecord(
       CRow(1: JInt, "bbb")))
     assertEquals(1, testHarness.numProcessingTimeTimers())
-    // 1 left timer(5), 1 left key(1), 1 join cnt
+    // lkeys(1) rkeys() timer_key_time(1:5)
     assertEquals(3, testHarness.numKeyedStateEntries())
     testHarness.setProcessingTime(2)
     testHarness.processElement1(new StreamRecord(
@@ -1168,7 +1163,8 @@ class JoinHarnessTest extends HarnessTestBase {
     testHarness.processElement1(new StreamRecord(
       CRow(2: JInt, "bbb")))
     assertEquals(2, testHarness.numProcessingTimeTimers())
-    // 2 left timer(5,6), 2 left key(1,2), 2 join cnt
+    // lkeys(1, 2) rkeys() timer_key_time(1:5, 2:6)
+    // l_join_cnt_keys(1, 2)
     assertEquals(6, testHarness.numKeyedStateEntries())
     testHarness.setProcessingTime(3)
 
@@ -1177,17 +1173,19 @@ class JoinHarnessTest extends HarnessTestBase {
       CRow(1: JInt, "Hi1")))
     testHarness.processElement2(new StreamRecord(
       CRow(false, 1: JInt, "bbb")))
-    // 2 left timer(5,6), 2 left keys(1,2), 2 join cnt, 1 right timer(7), 1 right key(1)
-    assertEquals(8, testHarness.numKeyedStateEntries())
-    assertEquals(3, testHarness.numProcessingTimeTimers())
+    // lkeys(1, 2) rkeys(1) timer_key_time(1:5, 2:6)
+    // l_join_cnt_keys(1, 2)
+    assertEquals(7, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
     testHarness.setProcessingTime(4)
     testHarness.processElement2(new StreamRecord(
       CRow(2: JInt, "ccc")))
     testHarness.processElement2(new StreamRecord(
       CRow(2: JInt, "Hello")))
-    // 2 left timer(5,6), 2 left keys(1,2), 2 join cnt, 2 right timer(7,8), 2 right key(1,2)
-    assertEquals(10, testHarness.numKeyedStateEntries())
-    assertEquals(4, testHarness.numProcessingTimeTimers())
+    // lkeys(1, 2) rkeys(1, 2) timer_key_time(1:5, 2:6)
+    // l_join_cnt_keys(1, 2)
+    assertEquals(8, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
 
     testHarness.processElement1(new StreamRecord(
       CRow(false, 1: JInt, "aaa")))
@@ -1197,22 +1195,29 @@ class JoinHarnessTest extends HarnessTestBase {
       CRow(false, 1: JInt, "Hi2")))
     testHarness.processElement2(new StreamRecord(
       CRow(false, 1: JInt, "Hi1")))
-    // expired left stream record with key value of 1
+    // lkeys(1, 2) rkeys(2) timer_key_time(1:8, 2:6)
+    // l_join_cnt_keys(1, 2)
+    assertEquals(7, testHarness.numKeyedStateEntries())
     testHarness.setProcessingTime(5)
+    // [1]. this will clean up left stream records with expired time of 5
     testHarness.processElement2(new StreamRecord(
       CRow(1: JInt, "Hi3")))
+    // [2]. no left records can be joined here, since they were cleaned up by [1]
     testHarness.processElement2(new StreamRecord(
       CRow(false, 1: JInt, "Hi3")))
-    // 1 left timer(6), 1 left keys(2), 1 join cnt, 2 right timer(7,8), 1 right key(2)
-    assertEquals(6, testHarness.numKeyedStateEntries())
-    assertEquals(3, testHarness.numProcessingTimeTimers())
+    // lkeys(1, 2) rkeys(2) timer_key_time(1:8, 2:6)
+    // l_join_cnt_keys(1, 2)
+    assertEquals(7, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
 
-    // expired all left stream record
+    // expired all records with key value of 2
     testHarness.setProcessingTime(6)
+    // lkeys(1) rkeys() timer_key_time(1:8)
+    // l_join_cnt_keys(1)
     assertEquals(3, testHarness.numKeyedStateEntries())
-    assertEquals(2, testHarness.numProcessingTimeTimers())
+    assertEquals(1, testHarness.numProcessingTimeTimers())
 
-    // expired right stream record with key value of 2
+    // expired all data
     testHarness.setProcessingTime(8)
     assertEquals(0, testHarness.numKeyedStateEntries())
     assertEquals(0, testHarness.numProcessingTimeTimers())
@@ -1253,6 +1258,12 @@ class JoinHarnessTest extends HarnessTestBase {
       CRow(false, 1: JInt, "bbb", 1: JInt, "Hi1")))
     expectedOutput.add(new StreamRecord(
       CRow(1: JInt, "bbb", null: JInt, null)))
+    // processing time of 5
+    // timer of 8, we use only one timer state now
+    expectedOutput.add(new StreamRecord(
+      CRow(false, 1: JInt, "bbb", null: JInt, null)))
+    expectedOutput.add(new StreamRecord(
+      CRow(1: JInt, "bbb", 1: JInt, "Hi3")))
     verify(expectedOutput, result)
 
     testHarness.close()
@@ -1305,32 +1316,36 @@ class JoinHarnessTest extends HarnessTestBase {
       CRow(1: JInt, "Hi1")))
     testHarness.processElement1(new StreamRecord(
       CRow(false, 1: JInt, "Hi1")))
-    assertEquals(5, testHarness.numKeyedStateEntries())
-    assertEquals(3, testHarness.numProcessingTimeTimers())
+    // lkeys() rkeys(1, 2) timer_key_time(1:5, 2:6)
+    assertEquals(4, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
     testHarness.setProcessingTime(4)
     testHarness.processElement1(new StreamRecord(
       CRow(2: JInt, "Hello1")))
-    assertEquals(7, testHarness.numKeyedStateEntries())
-    assertEquals(4, testHarness.numProcessingTimeTimers())
+    // lkeys(2) rkeys(1, 2) timer_key_time(1:5, 2:6)
+    assertEquals(5, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
 
     testHarness.processElement2(new StreamRecord(
       CRow(false, 1: JInt, "aaa")))
-    // expired right stream record with key value of 1
+    // expired stream records with key value of 1
     testHarness.setProcessingTime(5)
     testHarness.processElement1(new StreamRecord(
       CRow(1: JInt, "Hi2")))
     testHarness.processElement1(new StreamRecord(
       CRow(false, 1: JInt, "Hi2")))
-    assertEquals(5, testHarness.numKeyedStateEntries())
-    assertEquals(3, testHarness.numProcessingTimeTimers())
+    // lkeys(2) rkeys(2) timer_key_time(1:9, 2:6)
+    assertEquals(4, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
 
-    // expired all right stream record
+    // expired stream records with key value of 2
     testHarness.setProcessingTime(6)
-    assertEquals(3, testHarness.numKeyedStateEntries())
-    assertEquals(2, testHarness.numProcessingTimeTimers())
+    // lkeys() rkeys() timer_key_time(1:9)
+    assertEquals(1, testHarness.numKeyedStateEntries())
+    assertEquals(1, testHarness.numProcessingTimeTimers())
 
-    // expired left stream record with key value of 2
-    testHarness.setProcessingTime(8)
+    // expired all data
+    testHarness.setProcessingTime(9)
     assertEquals(0, testHarness.numKeyedStateEntries())
     assertEquals(0, testHarness.numProcessingTimeTimers())
 
@@ -1398,15 +1413,17 @@ class JoinHarnessTest extends HarnessTestBase {
     testHarness.processElement2(new StreamRecord(
       CRow(1: JInt, "bbb")))
     assertEquals(1, testHarness.numProcessingTimeTimers())
-    // 1 right timer(5), 1 right key(1), 1 join cnt
+    // lkeys() rkeys(1) timer_key_time(1:5)
+    // r_join_cnt_keys(1)
     assertEquals(3, testHarness.numKeyedStateEntries())
     testHarness.setProcessingTime(2)
     testHarness.processElement2(new StreamRecord(
       CRow(1: JInt, "aaa")))
     testHarness.processElement2(new StreamRecord(
       CRow(2: JInt, "bbb")))
+    // lkeys() rkeys(1, 2) timer_key_time(1:5, 2:6)
+    // r_join_cnt_keys(1, 2)
     assertEquals(2, testHarness.numProcessingTimeTimers())
-    // 2 right timer(5,6), 2 right key(1,2), 2 join cnt
     assertEquals(6, testHarness.numKeyedStateEntries())
     testHarness.setProcessingTime(3)
 
@@ -1415,17 +1432,19 @@ class JoinHarnessTest extends HarnessTestBase {
       CRow(1: JInt, "Hi1")))
     testHarness.processElement1(new StreamRecord(
       CRow(false, 1: JInt, "bbb")))
-    // 2 right timer(5,6), 2 right keys(1,2), 2 join cnt, 1 left timer(7), 1 left key(1)
-    assertEquals(8, testHarness.numKeyedStateEntries())
-    assertEquals(3, testHarness.numProcessingTimeTimers())
+    // lkeys(1) rkeys(1, 2) timer_key_time(1:5, 2:6)
+    // r_join_cnt_keys(1, 2)
+    assertEquals(7, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
     testHarness.setProcessingTime(4)
     testHarness.processElement1(new StreamRecord(
       CRow(2: JInt, "ccc")))
     testHarness.processElement1(new StreamRecord(
       CRow(2: JInt, "Hello")))
-    // 2 right timer(5,6), 2 right keys(1,2), 2 join cnt, 2 left timer(7,8), 2 left key(1,2)
-    assertEquals(10, testHarness.numKeyedStateEntries())
-    assertEquals(4, testHarness.numProcessingTimeTimers())
+    // lkeys(1, 2) rkeys(1, 2) timer_key_time(1:5, 2:6)
+    // r_join_cnt_keys(1, 2)
+    assertEquals(8, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
 
     testHarness.processElement2(new StreamRecord(
       CRow(false, 1: JInt, "aaa")))
@@ -1435,22 +1454,27 @@ class JoinHarnessTest extends HarnessTestBase {
       CRow(false, 1: JInt, "Hi2")))
     testHarness.processElement1(new StreamRecord(
       CRow(false, 1: JInt, "Hi1")))
-    // expired right stream record with key value of 1
+    // lkeys(2) rkeys(1, 2) timer_key_time(1:8, 2:6)
+    // r_join_cnt_keys(1, 2)
+    assertEquals(7, testHarness.numKeyedStateEntries())
     testHarness.setProcessingTime(5)
     testHarness.processElement1(new StreamRecord(
       CRow(1: JInt, "Hi3")))
     testHarness.processElement1(new StreamRecord(
       CRow(false, 1: JInt, "Hi3")))
-    // 1 right timer(6), 1 right keys(2), 1 join cnt, 2 left timer(7,8), 1 left key(2)
-    assertEquals(6, testHarness.numKeyedStateEntries())
-    assertEquals(3, testHarness.numProcessingTimeTimers())
+    // lkeys(2) rkeys(1, 2) timer_key_time(1:8, 2:6)
+    // r_join_cnt_keys(1, 2)
+    assertEquals(7, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
 
-    // expired all right stream record
+    // expired all stream records with key value of 2
+    // lkeys() rkeys(1) timer_key_time(1:8)
+    // r_join_cnt_keys(1)
     testHarness.setProcessingTime(6)
     assertEquals(3, testHarness.numKeyedStateEntries())
-    assertEquals(2, testHarness.numProcessingTimeTimers())
+    assertEquals(1, testHarness.numProcessingTimeTimers())
 
-    // expired left stream record with key value of 2
+    // expired all data
     testHarness.setProcessingTime(8)
     assertEquals(0, testHarness.numKeyedStateEntries())
     assertEquals(0, testHarness.numProcessingTimeTimers())
@@ -1491,6 +1515,12 @@ class JoinHarnessTest extends HarnessTestBase {
       CRow(false, 1: JInt, "Hi1", 1: JInt, "bbb")))
     expectedOutput.add(new StreamRecord(
       CRow(null: JInt, null, 1: JInt, "bbb")))
+    // processing time of 5
+    // timer of 8, we use only one timer state now
+    expectedOutput.add(new StreamRecord(
+      CRow(false, null: JInt, null, 1: JInt, "bbb")))
+    expectedOutput.add(new StreamRecord(
+      CRow(1: JInt, "Hi3", 1: JInt, "bbb")))
     verify(expectedOutput, result)
 
     testHarness.close()
@@ -1524,8 +1554,8 @@ class JoinHarnessTest extends HarnessTestBase {
       CRow(1: JInt, "bbb")))
     testHarness.processElement1(new StreamRecord(
       CRow(1: JInt, "ccc")))
+    // lkeys(1) rkeys() timer_key_time(1:5)
     assertEquals(1, testHarness.numProcessingTimeTimers())
-    // 1 left timer(5), 1 left key(1)
     assertEquals(2, testHarness.numKeyedStateEntries())
 
     testHarness.setProcessingTime(2)
@@ -1534,8 +1564,7 @@ class JoinHarnessTest extends HarnessTestBase {
     testHarness.processElement2(new StreamRecord(
       CRow(2: JInt, "ccc")))
     assertEquals(2, testHarness.numProcessingTimeTimers())
-    // 1 left timer(5), 1 left key(1)
-    // 1 right timer(6), 1 right key(1)
+    // lkeys(1) rkeys(2) timer_key_time(1:5, 2:6)
     assertEquals(4, testHarness.numKeyedStateEntries())
 
     testHarness.setProcessingTime(3)
@@ -1543,18 +1572,16 @@ class JoinHarnessTest extends HarnessTestBase {
       CRow(2: JInt, "aaa")))
     testHarness.processElement1(new StreamRecord(
       CRow(2: JInt, "ddd")))
-    assertEquals(3, testHarness.numProcessingTimeTimers())
-    // 2 left timer(5,7), 2 left key(1,2)
-    // 1 right timer(6), 1 right key(1)
-    assertEquals(6, testHarness.numKeyedStateEntries())
+    // lkeys(1, 2) rkeys(2) timer_key_time(1:5, 2:6)
+    assertEquals(2, testHarness.numProcessingTimeTimers())
+    assertEquals(5, testHarness.numKeyedStateEntries())
     testHarness.processElement2(new StreamRecord(
       CRow(1: JInt, "aaa")))
     testHarness.processElement2(new StreamRecord(
       CRow(1: JInt, "ddd")))
-    assertEquals(4, testHarness.numProcessingTimeTimers())
-    // 2 left timer(5,7), 2 left key(1,2)
-    // 2 right timer(6,7), 2 right key(1,2)
-    assertEquals(8, testHarness.numKeyedStateEntries())
+    // lkeys(1, 2) rkeys(1, 2) timer_key_time(1:5, 2:6)
+    assertEquals(2, testHarness.numProcessingTimeTimers())
+    assertEquals(6, testHarness.numKeyedStateEntries())
 
     testHarness.setProcessingTime(4)
     testHarness.processElement1(new StreamRecord(
@@ -1565,28 +1592,26 @@ class JoinHarnessTest extends HarnessTestBase {
       CRow(false, 1: JInt, "aaa")))
     testHarness.processElement2(new StreamRecord(
       CRow(false, 1: JInt, "ddd")))
-    assertEquals(4, testHarness.numProcessingTimeTimers())
-    // 2 left timer(5,7), 1 left key(1)
-    // 2 right timer(6,7), 1 right key(2)
-    assertEquals(6, testHarness.numKeyedStateEntries())
+    // lkeys(1) rkeys(2) timer_key_time(1:8, 2:6)
+    assertEquals(2, testHarness.numProcessingTimeTimers())
+    assertEquals(4, testHarness.numKeyedStateEntries())
 
     testHarness.setProcessingTime(5)
-    assertEquals(3, testHarness.numProcessingTimeTimers())
-    // 1 left timer(7)
-    // 2 right timer(6,7), 1 right key(2)
+    assertEquals(2, testHarness.numProcessingTimeTimers())
     assertEquals(4, testHarness.numKeyedStateEntries())
 
     testHarness.setProcessingTime(6)
-    assertEquals(2, testHarness.numProcessingTimeTimers())
-    // 1 left timer(7)
-    // 2 right timer(7)
+    // lkeys(1) rkeys() timer_key_time(1:8)
+    assertEquals(1, testHarness.numProcessingTimeTimers())
     assertEquals(2, testHarness.numKeyedStateEntries())
 
     testHarness.setProcessingTime(7)
-    assertEquals(0, testHarness.numProcessingTimeTimers())
-    assertEquals(0, testHarness.numKeyedStateEntries())
+    assertEquals(1, testHarness.numProcessingTimeTimers())
+    assertEquals(2, testHarness.numKeyedStateEntries())
 
     testHarness.setProcessingTime(8)
+    assertEquals(0, testHarness.numProcessingTimeTimers())
+    assertEquals(0, testHarness.numKeyedStateEntries())
     testHarness.processElement1(new StreamRecord(
       CRow(1: JInt, "bbb")))
     testHarness.processElement2(new StreamRecord(
@@ -1693,8 +1718,9 @@ class JoinHarnessTest extends HarnessTestBase {
       CRow(1: JInt, "bbb")))
     testHarness.processElement1(new StreamRecord(
       CRow(1: JInt, "ccc")))
+    // lkeys(1) rkeys() timer_key_time(1:5)
+    // l_join_cnt_keys(1) r_join_cnt_keys()
     assertEquals(1, testHarness.numProcessingTimeTimers())
-    // 1 left timer(5), 1 left key(1), 1 left joincnt key(1)
     assertEquals(3, testHarness.numKeyedStateEntries())
 
     testHarness.setProcessingTime(2)
@@ -1702,9 +1728,9 @@ class JoinHarnessTest extends HarnessTestBase {
       CRow(2: JInt, "bbb")))
     testHarness.processElement2(new StreamRecord(
       CRow(2: JInt, "ccc")))
+    // lkeys(1) rkeys(2) timer_key_time(1:5, 2:6)
+    // l_join_cnt_keys(1) r_join_cnt_keys(2)
     assertEquals(2, testHarness.numProcessingTimeTimers())
-    // 1 left timer(5), 1 left key(1), 1 left joincnt key(1)
-    // 1 right timer(6), 1 right key(1), 1 right joincnt key(1)
     assertEquals(6, testHarness.numKeyedStateEntries())
 
     testHarness.setProcessingTime(3)
@@ -1712,46 +1738,46 @@ class JoinHarnessTest extends HarnessTestBase {
       CRow(2: JInt, "aaa")))
     testHarness.processElement1(new StreamRecord(
       CRow(2: JInt, "ddd")))
-    assertEquals(3, testHarness.numProcessingTimeTimers())
-    // 2 left timer(5,7), 2 left key(1,2), 2 left joincnt key(1,2)
-    // 1 right timer(6), 1 right key(1), 1 right joincnt key(1)
-    assertEquals(9, testHarness.numKeyedStateEntries())
+    // lkeys(1, 2) rkeys(2) timer_key_time(1:5, 2:6)
+    // l_join_cnt_keys(1, 2) r_join_cnt_keys(2)
+    assertEquals(2, testHarness.numProcessingTimeTimers())
+    assertEquals(8, testHarness.numKeyedStateEntries())
     testHarness.processElement2(new StreamRecord(
       CRow(1: JInt, "aaa")))
     testHarness.processElement2(new StreamRecord(
       CRow(1: JInt, "ddd")))
-    assertEquals(4, testHarness.numProcessingTimeTimers())
-    // 2 left timer(5,7), 2 left key(1,2), 2 left joincnt key(1,2)
-    // 2 right timer(6,7), 2 right key(1,2), 2 right joincnt key(1,2)
-    assertEquals(12, testHarness.numKeyedStateEntries())
+    // lkeys(1, 2) rkeys(1, 2) timer_key_time(1:5, 2:6)
+    // l_join_cnt_keys(1, 2) r_join_cnt_keys(1, 2)
+    assertEquals(2, testHarness.numProcessingTimeTimers())
+    assertEquals(10, testHarness.numKeyedStateEntries())
 
     testHarness.setProcessingTime(4)
     testHarness.processElement1(new StreamRecord(
       CRow(false, 2: JInt, "aaa")))
     testHarness.processElement2(new StreamRecord(
       CRow(false, 1: JInt, "ddd")))
-    assertEquals(4, testHarness.numProcessingTimeTimers())
-    // 2 left timer(5,7), 2 left key(1,2), 2 left joincnt key(1,2)
-    // 2 right timer(6,7), 2 right key(1,2), 2 right joincnt key(1,2)
-    assertEquals(12, testHarness.numKeyedStateEntries())
+    // lkeys(1, 2) rkeys(1, 2) timer_key_time(1:8, 2:6)
+    // l_join_cnt_keys(1, 2) r_join_cnt_keys(1, 2)
+    assertEquals(2, testHarness.numProcessingTimeTimers())
+    assertEquals(10, testHarness.numKeyedStateEntries())
 
     testHarness.setProcessingTime(5)
-    assertEquals(3, testHarness.numProcessingTimeTimers())
-    // 1 left timer(7), 1 left key(2), 1 left joincnt key(2)
-    // 2 right timer(6,7), 2 right key(1,2), 2 right joincnt key(1,2)
-    assertEquals(9, testHarness.numKeyedStateEntries())
+    assertEquals(2, testHarness.numProcessingTimeTimers())
+    assertEquals(10, testHarness.numKeyedStateEntries())
 
     testHarness.setProcessingTime(6)
-    assertEquals(2, testHarness.numProcessingTimeTimers())
-    // 1 left timer(7), 1 left key(2), 1 left joincnt key(2)
-    // 1 right timer(7), 1 right key(2), 1 right joincnt key(2)
-    assertEquals(6, testHarness.numKeyedStateEntries())
+    // lkeys(1) rkeys(1) timer_key_time(1:8)
+    // l_join_cnt_keys(1) r_join_cnt_keys(1)
+    assertEquals(1, testHarness.numProcessingTimeTimers())
+    assertEquals(5, testHarness.numKeyedStateEntries())
 
     testHarness.setProcessingTime(7)
-    assertEquals(0, testHarness.numProcessingTimeTimers())
-    assertEquals(0, testHarness.numKeyedStateEntries())
+    assertEquals(1, testHarness.numProcessingTimeTimers())
+    assertEquals(5, testHarness.numKeyedStateEntries())
 
     testHarness.setProcessingTime(8)
+    assertEquals(0, testHarness.numProcessingTimeTimers())
+    assertEquals(0, testHarness.numKeyedStateEntries())
     testHarness.processElement1(new StreamRecord(
       CRow(1: JInt, "bbb")))
     testHarness.processElement2(new StreamRecord(
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/StateCleaningCountTriggerHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/StateCleaningCountTriggerHarnessTest.scala
index 7f9c0ef..25395be 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/StateCleaningCountTriggerHarnessTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/StateCleaningCountTriggerHarnessTest.scala
@@ -80,8 +80,8 @@ class StateCleaningCountTriggerHarnessTest {
       TriggerResult.CONTINUE,
       testHarness.processElement(new StreamRecord(1), GlobalWindow.get))
 
-    // have two timers 6001 and 7002
-    assertEquals(2, testHarness.numProcessingTimeTimers)
+    // have one timer 7002
+    assertEquals(1, testHarness.numProcessingTimeTimers)
     assertEquals(0, testHarness.numEventTimeTimers)
     assertEquals(2, testHarness.numStateEntries)
     assertEquals(2, testHarness.numStateEntries(GlobalWindow.get))
@@ -116,9 +116,6 @@ class StateCleaningCountTriggerHarnessTest {
 
     // try to trigger onProcessingTime method via 7002, and all states are cleared
     val timesIt = testHarness.advanceProcessingTime(7002).iterator()
-    assertEquals(
-      TriggerResult.CONTINUE,
-      timesIt.next().f1)
 
     assertEquals(
       TriggerResult.FIRE_AND_PURGE,
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/operators/KeyedProcessFunctionWithCleanupStateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/operators/KeyedProcessFunctionWithCleanupStateTest.scala
index fe90a5f..1c02889 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/operators/KeyedProcessFunctionWithCleanupStateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/operators/KeyedProcessFunctionWithCleanupStateTest.scala
@@ -110,7 +110,7 @@ private class MockedKeyedProcessFunction(queryConfig: StreamQueryConfig)
       out: Collector[String]): Unit = {
 
     val curTime = ctx.timerService().currentProcessingTime()
-    registerProcessingCleanupTimer(ctx, curTime)
+    processCleanupTimer(ctx, curTime)
     state.update(value._2)
   }
 
@@ -119,8 +119,12 @@ private class MockedKeyedProcessFunction(queryConfig: StreamQueryConfig)
       ctx: KeyedProcessFunction[String, (String, String), String]#OnTimerContext,
       out: Collector[String]): Unit = {
 
-    if (needToCleanupState(timestamp)) {
-      cleanupState(state)
+    if (stateCleaningEnabled) {
+      val cleanupTime = cleanupTimeState.value()
+      if (null != cleanupTime && timestamp == cleanupTime) {
+        // clean up
+        cleanupState(state)
+      }
     }
   }
 }
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/operators/ProcessFunctionWithCleanupStateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/operators/ProcessFunctionWithCleanupStateTest.scala
index 519b03f..6c0ca1a 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/operators/ProcessFunctionWithCleanupStateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/operators/ProcessFunctionWithCleanupStateTest.scala
@@ -110,7 +110,7 @@ private class MockedProcessFunction(queryConfig: StreamQueryConfig)
       out: Collector[String]): Unit = {
 
     val curTime = ctx.timerService().currentProcessingTime()
-    registerProcessingCleanupTimer(ctx, curTime)
+    processCleanupTimer(ctx, curTime)
     state.update(value._2)
   }
 
@@ -119,7 +119,7 @@ private class MockedProcessFunction(queryConfig: StreamQueryConfig)
       ctx: ProcessFunction[(String, String), String]#OnTimerContext,
       out: Collector[String]): Unit = {
 
-    if (needToCleanupState(timestamp)) {
+    if (stateCleaningEnabled) {
       cleanupState(state)
     }
   }