You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by th...@apache.org on 2019/03/29 19:37:49 UTC

[beam] branch master updated: [BEAM-6876] Cleanup user state in portable Flink Runner

This is an automated email from the ASF dual-hosted git repository.

thw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new baa09ae  [BEAM-6876] Cleanup user state in portable Flink Runner
     new afcf5ee  Merge pull request #8118: [BEAM-6876] Cleanup user state in portable Flink Runner
baa09ae is described below

commit baa09aedb8eb4d60e3c192f302a25c33fdbc9460
Author: Maximilian Michels <mx...@apache.org>
AuthorDate: Fri Mar 22 16:25:14 2019 +0100

    [BEAM-6876] Cleanup user state in portable Flink Runner
    
    Previously, user state had to be cleaned up explicitly by the user, via timers
    which fire at the end of a window. This change uses the StatefulDoFnRunner to
    set timers which clean up user state at the end of each window.
---
 .../org/apache/beam/runners/core/DoFnRunners.java  |   2 +-
 .../beam/runners/core/StatefulDoFnRunner.java      |  15 +-
 .../runners/flink/translation/utils/NoopLock.java  |  72 --------
 .../streaming/ExecutableStageDoFnOperator.java     | 185 +++++++++++++++++----
 .../streaming/ExecutableStageDoFnOperatorTest.java | 132 ++++++++++++++-
 5 files changed, 287 insertions(+), 119 deletions(-)

diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java
index f7646ea..3d929d7 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java
@@ -100,7 +100,7 @@ public class DoFnRunners {
           DoFn<InputT, OutputT> fn,
           DoFnRunner<InputT, OutputT> doFnRunner,
           WindowingStrategy<?, ?> windowingStrategy,
-          CleanupTimer cleanupTimer,
+          CleanupTimer<InputT> cleanupTimer,
           StateCleaner<W> stateCleaner) {
     return new StatefulDoFnRunner<>(doFnRunner, windowingStrategy, cleanupTimer, stateCleaner);
   }
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/StatefulDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/StatefulDoFnRunner.java
index 6cd580c..14a9502 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/StatefulDoFnRunner.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/StatefulDoFnRunner.java
@@ -52,13 +52,13 @@ public class StatefulDoFnRunner<InputT, OutputT, W extends BoundedWindow>
   private final WindowingStrategy<?, ?> windowingStrategy;
   private final Counter droppedDueToLateness =
       Metrics.counter(StatefulDoFnRunner.class, DROPPED_DUE_TO_LATENESS_COUNTER);
-  private final CleanupTimer cleanupTimer;
+  private final CleanupTimer<InputT> cleanupTimer;
   private final StateCleaner stateCleaner;
 
   public StatefulDoFnRunner(
       DoFnRunner<InputT, OutputT> doFnRunner,
       WindowingStrategy<?, ?> windowingStrategy,
-      CleanupTimer cleanupTimer,
+      CleanupTimer<InputT> cleanupTimer,
       StateCleaner<W> stateCleaner) {
     this.doFnRunner = doFnRunner;
     this.windowingStrategy = windowingStrategy;
@@ -103,7 +103,7 @@ public class StatefulDoFnRunner<InputT, OutputT, W extends BoundedWindow>
             window,
             cleanupTimer.currentInputWatermarkTime());
       } else {
-        cleanupTimer.setForWindow(window);
+        cleanupTimer.setForWindow(value.getValue(), window);
         doFnRunner.processElement(value);
       }
     }
@@ -151,7 +151,7 @@ public class StatefulDoFnRunner<InputT, OutputT, W extends BoundedWindow>
    * time or (b) not need a timer at all because it is a batch runner that discards state when it is
    * done.
    */
-  public interface CleanupTimer {
+  public interface CleanupTimer<InputT> {
 
     /**
      * Return the current, local input watermark timestamp for this computation in the {@link
@@ -160,7 +160,7 @@ public class StatefulDoFnRunner<InputT, OutputT, W extends BoundedWindow>
     Instant currentInputWatermarkTime();
 
     /** Set the garbage collect time of the window to timer. */
-    void setForWindow(BoundedWindow window);
+    void setForWindow(InputT value, BoundedWindow window);
 
     /** Checks whether the given timer is a cleanup timer for the window. */
     boolean isForWindow(
@@ -174,7 +174,8 @@ public class StatefulDoFnRunner<InputT, OutputT, W extends BoundedWindow>
   }
 
   /** A {@link StatefulDoFnRunner.CleanupTimer} implemented via {@link TimerInternals}. */
-  public static class TimeInternalsCleanupTimer implements StatefulDoFnRunner.CleanupTimer {
+  public static class TimeInternalsCleanupTimer<InputT>
+      implements StatefulDoFnRunner.CleanupTimer<InputT> {
 
     public static final String GC_TIMER_ID = "__StatefulParDoGcTimerId";
 
@@ -202,7 +203,7 @@ public class StatefulDoFnRunner<InputT, OutputT, W extends BoundedWindow>
     }
 
     @Override
-    public void setForWindow(BoundedWindow window) {
+    public void setForWindow(InputT input, BoundedWindow window) {
       Instant gcTime = LateDataUtils.garbageCollectionTime(window, windowingStrategy);
       // make sure this fires after any window.maxTimestamp() timers
       gcTime = gcTime.plus(GC_DELAY_MS);
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/NoopLock.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/NoopLock.java
deleted file mode 100644
index ee65c22..0000000
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/NoopLock.java
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.flink.translation.utils;
-
-import java.io.Serializable;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.locks.Condition;
-import java.util.concurrent.locks.Lock;
-import javax.annotation.Nonnull;
-
-/**
- * A lock which can always be acquired. It should not be used when a proper lock is required, but it
- * is useful as a performance optimization when locking is not necessary but the code paths have to
- * be shared between the locking and the non-locking variant.
- *
- * <p>For example, in {@link
- * org.apache.beam.runners.flink.translation.wrappers.streaming.ExecutableStageDoFnOperator}, the
- * locking on the state backend is only required when both timers and state are used.
- */
-public class NoopLock implements Lock, Serializable {
-
-  private static NoopLock instance;
-
-  public static NoopLock get() {
-    if (instance == null) {
-      instance = new NoopLock();
-    }
-    return instance;
-  }
-
-  private NoopLock() {}
-
-  @Override
-  public void lock() {}
-
-  @Override
-  public void lockInterruptibly() {}
-
-  @Override
-  public boolean tryLock() {
-    return true;
-  }
-
-  @Override
-  public boolean tryLock(long time, @Nonnull TimeUnit unit) {
-    return true;
-  }
-
-  @Override
-  public void unlock() {}
-
-  @Nonnull
-  @Override
-  public Condition newCondition() {
-    throw new UnsupportedOperationException("Not implemented");
-  }
-}
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
index 210da7d..dcc473f 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
@@ -35,22 +35,27 @@ import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.locks.Lock;
 import java.util.concurrent.locks.ReentrantLock;
 import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey.TypeCase;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.runners.core.DoFnRunner;
+import org.apache.beam.runners.core.DoFnRunners;
+import org.apache.beam.runners.core.LateDataUtils;
 import org.apache.beam.runners.core.StateInternals;
 import org.apache.beam.runners.core.StateNamespace;
 import org.apache.beam.runners.core.StateNamespaces;
 import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.core.StatefulDoFnRunner;
 import org.apache.beam.runners.core.TimerInternals;
 import org.apache.beam.runners.core.construction.Timer;
 import org.apache.beam.runners.core.construction.graph.ExecutableStage;
+import org.apache.beam.runners.core.construction.graph.UserStateReference;
 import org.apache.beam.runners.flink.metrics.FlinkMetricContainer;
 import org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageContext;
 import org.apache.beam.runners.flink.translation.functions.FlinkStreamingSideInputHandlerFactory;
-import org.apache.beam.runners.flink.translation.utils.NoopLock;
 import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
 import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory;
 import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors;
@@ -104,8 +109,7 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
   private final FlinkExecutableStageContext.Factory contextFactory;
   private final Map<String, TupleTag<?>> outputMap;
   private final Map<RunnerApi.ExecutableStagePayload.SideInputId, PCollectionView<?>> sideInputIds;
-  private final boolean usesTimers;
-  /** A lock which has to be acquired when concurrently accessing state and setting timers. */
+  /** A lock which has to be acquired when concurrently accessing state and timers. */
   private final Lock stateBackendLock;
 
   private transient FlinkExecutableStageContext stageContext;
@@ -158,19 +162,7 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
     this.contextFactory = contextFactory;
     this.outputMap = outputMap;
     this.sideInputIds = sideInputIds;
-    this.usesTimers = payload.getTimersCount() > 0;
-    if (usesTimers) {
-      // We only need to lock if we have timers. 1) Timers can
-      // interfere with state access. 2) Even without state access,
-      // setting timers can interfere with firing timers.
-      this.stateBackendLock = new ReentrantLock();
-    } else {
-      // Plain state access is guaranteed to not interfere with the state
-      // backend. The current key of the state backend is set manually before
-      // accessing the keyed state. Flink's automatic key setting before
-      // processing elements is overridden in this class.
-      this.stateBackendLock = NoopLock.get();
-    }
+    this.stateBackendLock = new ReentrantLock();
   }
 
   @Override
@@ -351,25 +343,26 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
     }
   }
 
+  /**
+   * Intentionally a no-op. This is only relevant for stateful DoFns: the operator controls the key
+   * of the state backend itself and must avoid any concurrent setting of the current active key.
+   * Overriding this method also avoids unnecessary serialization, as the key would otherwise have
+   * to be encoded as a byte array.
+   */
  @Override
-  public void setKeyContextElement1(StreamRecord record) throws Exception {
-    // Note: This is only relevant when we have a stateful DoFn.
-    // We want to control the key of the state backend ourselves and
-    // we must avoid any concurrent setting of the current active key.
-    // By overwriting this, we also prevent unnecessary serialization
-    // as the key has to be encoded as a byte array.
-  }
-
+  public void setKeyContextElement1(StreamRecord record) {}
+
+  /**
+   * Intentionally a no-op, due to the asynchronous nature of processing elements in the SDK
+   * harness. The Flink runtime sets the current key before calling {@code processElement}, but
+   * this does not work for elements sent to the SDK harness, which may be processed at an
+   * arbitrary later point in time. State for keys is also accessed asynchronously via state
+   * requests.
+   *
+   * <p>The key is set explicitly only where needed: 1) state requests, 2) setting/firing timers.
+   */
   @Override
-  public void setCurrentKey(Object key) {
-    // We don't need to set anything, the key is set manually on the state backend in
-    // the case of state access. For timers, the key will be extracted from the timer
-    // element, i.e. in HeapInternalTimerService
-    if (!usesTimers) {
-      throw new UnsupportedOperationException(
-          "Current key for state backend can only be set by state requests from SDK workers or when processing timers.");
-    }
-  }
+  public void setCurrentKey(Object key) {}
 
   @Override
   public Object getCurrentKey() {
@@ -464,7 +457,8 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
             (Coder<BoundedWindow>) windowingStrategy.getWindowFn().windowCoder(),
             keySelector,
             this::setTimer);
-    return sdkHarnessRunner;
+
+    return ensureStateCleanup(sdkHarnessRunner);
   }
 
   @Override
@@ -713,6 +707,129 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
     }
   }
 
+  private DoFnRunner<InputT, OutputT> ensureStateCleanup(
+      SdkHarnessDoFnRunner<InputT, OutputT> sdkHarnessRunner) {
+    if (keyCoder == null) {
+      // Unkeyed input carries no user state to clean up, since stateful
+      // DoFns require keyed input; return the runner unwrapped.
+      return sdkHarnessRunner;
+    }
+    // Wrap in a StatefulDoFnRunner, which sets a per-key/window GC timer via CleanupTimer
+    Coder windowCoder = windowingStrategy.getWindowFn().windowCoder();
+    CleanupTimer<InputT> cleanupTimer =
+        new CleanupTimer<>(
+            timerInternals,
+            stateBackendLock,
+            windowingStrategy,
+            keyCoder,
+            windowCoder,
+            sdkHarnessRunner::setCurrentTimerKey,
+            getKeyedStateBackend());
+
+    List<String> userStates =
+        executableStage.getUserStates().stream()
+            .map(UserStateReference::localName)
+            .collect(Collectors.toList());
+    StateCleaner stateCleaner = new StateCleaner(userStates, windowCoder, keyedStateInternals);
+
+    return DoFnRunners.defaultStatefulDoFnRunner(
+        doFn, sdkHarnessRunner, windowingStrategy, cleanupTimer, stateCleaner);
+  }
+
+  static class CleanupTimer<InputT> implements StatefulDoFnRunner.CleanupTimer<InputT> {
+    private static final String GC_TIMER_ID = "__user-state-cleanup__";
+
+    private final TimerInternals timerInternals;
+    private final Lock stateBackendLock;
+    private final WindowingStrategy windowingStrategy;
+    private final Coder keyCoder;
+    private final Coder windowCoder;
+    private final Consumer<ByteBuffer> currentKeyConsumer;
+    private final KeyedStateBackend<ByteBuffer> keyedStateBackend;
+
+    CleanupTimer(
+        TimerInternals timerInternals,
+        Lock stateBackendLock,
+        WindowingStrategy windowingStrategy,
+        Coder keyCoder,
+        Coder windowCoder,
+        Consumer<ByteBuffer> currentKeyConsumer,
+        KeyedStateBackend<ByteBuffer> keyedStateBackend) {
+      this.timerInternals = timerInternals;
+      this.stateBackendLock = stateBackendLock;
+      this.windowingStrategy = windowingStrategy;
+      this.keyCoder = keyCoder;
+      this.windowCoder = windowCoder;
+      this.currentKeyConsumer = currentKeyConsumer;
+      this.keyedStateBackend = keyedStateBackend;
+    }
+
+    @Override
+    public Instant currentInputWatermarkTime() {
+      return timerInternals.currentInputWatermarkTime();
+    }
+
+    @Override
+    public void setForWindow(InputT input, BoundedWindow window) {
+      Preconditions.checkNotNull(input, "Null input passed to CleanupTimer");
+      // make sure this fires after any window.maxTimestamp() timers
+      Instant gcTime = LateDataUtils.garbageCollectionTime(window, windowingStrategy).plus(1);
+      final ByteBuffer key;
+      try {
+        key = ByteBuffer.wrap(CoderUtils.encodeToByteArray(keyCoder, ((KV) input).getKey()));
+      } catch (CoderException e) {
+        throw new RuntimeException("Failed to encode key for Flink state backend", e);
+      }
+      // Ensure the state backend is not concurrently accessed by the state requests
+      stateBackendLock.lock();
+      try {
+        // Set the key in both places to ensure correct timer registration:
+        // 1) via currentKeyConsumer, for setting the timer
+        currentKeyConsumer.accept(key);
+        // 2) on the keyed state backend, for timer deduplication
+        keyedStateBackend.setCurrentKey(key);
+        timerInternals.setTimer(
+            StateNamespaces.window(windowCoder, window),
+            GC_TIMER_ID,
+            gcTime,
+            TimeDomain.EVENT_TIME);
+      } finally {
+        stateBackendLock.unlock();
+      }
+    }
+
+    @Override
+    public boolean isForWindow(
+        String timerId, BoundedWindow window, Instant timestamp, TimeDomain timeDomain) {
+      boolean isEventTimer = timeDomain.equals(TimeDomain.EVENT_TIME);
+      Instant gcTime = LateDataUtils.garbageCollectionTime(window, windowingStrategy).plus(1);
+      return isEventTimer && GC_TIMER_ID.equals(timerId) && gcTime.equals(timestamp);
+    }
+  }
+
+  static class StateCleaner implements StatefulDoFnRunner.StateCleaner<BoundedWindow> {
+
+    private final List<String> userStateNames;
+    private final Coder windowCoder;
+    private final StateInternals stateInternals;
+
+    StateCleaner(List<String> userStateNames, Coder windowCoder, StateInternals stateInternals) {
+      this.userStateNames = userStateNames;
+      this.windowCoder = windowCoder;
+      this.stateInternals = stateInternals;
+    }
+
+    @Override
+    public void clearForWindow(BoundedWindow window) {
+      // Executed in the context of onTimer(..) where the correct key will be set
+      StateNamespace namespace = StateNamespaces.window(windowCoder, window);
+      for (String userState : userStateNames) {
+        BagState<?> state = stateInternals.state(namespace, StateTags.bag(userState, null));
+        state.clear();
+      }
+    }
+  }
+
   private static class NoOpDoFn<InputT, OutputT> extends DoFn<InputT, OutputT> {
     @ProcessElement
     public void doNothing(ProcessContext context) {}
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/ExecutableStageDoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
similarity index 76%
rename from runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/ExecutableStageDoFnOperatorTest.java
rename to runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
index f306a40..2b51583 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/ExecutableStageDoFnOperatorTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
@@ -15,31 +15,41 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.beam.runners.flink.streaming;
+package org.apache.beam.runners.flink.translation.wrappers.streaming;
 
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.collection.IsIterableContainingInOrder.contains;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertThat;
 import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 
+import java.nio.ByteBuffer;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.locks.Lock;
+import java.util.function.Consumer;
+import javax.annotation.Nullable;
 import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
 import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
+import org.apache.beam.runners.core.InMemoryStateInternals;
+import org.apache.beam.runners.core.InMemoryTimerInternals;
+import org.apache.beam.runners.core.StateNamespaces;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.core.StatefulDoFnRunner;
 import org.apache.beam.runners.flink.FlinkPipelineOptions;
+import org.apache.beam.runners.flink.metrics.DoFnRunnerWithMetricsUpdate;
 import org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageContext;
-import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator;
-import org.apache.beam.runners.flink.translation.wrappers.streaming.ExecutableStageDoFnOperator;
+import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
 import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
 import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory;
 import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors;
@@ -48,22 +58,31 @@ import org.apache.beam.runners.fnexecution.control.StageBundleFactory;
 import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
 import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.WindowingStrategy;
 import org.apache.beam.vendor.grpc.v1p13p1.com.google.protobuf.Struct;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
 import org.apache.commons.lang3.SerializationUtils;
 import org.apache.flink.api.common.cache.DistributedCache;
 import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
 import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
 import org.apache.flink.util.OutputTag;
 import org.junit.Before;
@@ -341,6 +360,101 @@ public class ExecutableStageDoFnOperatorTest {
   }
 
   @Test
+  @SuppressWarnings("unchecked")
+  public void testEnsureStateCleanupWithKeyedInput() throws Exception {
+    TupleTag<Integer> mainOutput = new TupleTag<>("main-output");
+    DoFnOperator.MultiOutputOutputManagerFactory<Integer> outputManagerFactory =
+        new DoFnOperator.MultiOutputOutputManagerFactory(mainOutput, VarIntCoder.of());
+    VarIntCoder keyCoder = VarIntCoder.of();
+    ExecutableStageDoFnOperator<Integer, Integer> operator =
+        getOperator(mainOutput, Collections.emptyList(), outputManagerFactory, keyCoder);
+
+    KeyedOneInputStreamOperatorTestHarness<Integer, WindowedValue<Integer>, WindowedValue<Integer>>
+        testHarness =
+            new KeyedOneInputStreamOperatorTestHarness(
+                operator, val -> val, new CoderTypeInformation<>(keyCoder));
+
+    RemoteBundle bundle = Mockito.mock(RemoteBundle.class);
+    when(bundle.getInputReceivers())
+        .thenReturn(
+            ImmutableMap.<String, FnDataReceiver<WindowedValue>>builder()
+                .put("input", Mockito.mock(FnDataReceiver.class))
+                .build());
+    when(stageBundleFactory.getBundle(any(), any(), any())).thenReturn(bundle);
+
+    testHarness.open();
+
+    Object doFnRunner = Whitebox.getInternalState(operator, "doFnRunner");
+    assertThat(doFnRunner, instanceOf(DoFnRunnerWithMetricsUpdate.class));
+
+    // A keyed operator must wrap its runner in a StatefulDoFnRunner, which cleans up state
+    Object statefulDoFnRunner = Whitebox.getInternalState(doFnRunner, "delegate");
+    assertThat(statefulDoFnRunner, instanceOf(StatefulDoFnRunner.class));
+  }
+
+  @Test
+  public void testEnsureStateCleanupWithKeyedInputCleanupTimer() throws Exception {
+    InMemoryTimerInternals inMemoryTimerInternals = new InMemoryTimerInternals();
+    Consumer<ByteBuffer> keyConsumer = Mockito.mock(Consumer.class);
+    KeyedStateBackend keyedStateBackend = Mockito.mock(KeyedStateBackend.class);
+    Lock stateBackendLock = Mockito.mock(Lock.class);
+    StringUtf8Coder keyCoder = StringUtf8Coder.of();
+    GlobalWindow window = GlobalWindow.INSTANCE;
+    GlobalWindow.Coder windowCoder = GlobalWindow.Coder.INSTANCE;
+
+    // setForWindow must take the lock, set the encoded key, and register the GC timer
+    ExecutableStageDoFnOperator.CleanupTimer cleanupTimer =
+        new ExecutableStageDoFnOperator.CleanupTimer<>(
+            inMemoryTimerInternals,
+            stateBackendLock,
+            WindowingStrategy.globalDefault(),
+            keyCoder,
+            windowCoder,
+            keyConsumer,
+            keyedStateBackend);
+    cleanupTimer.setForWindow(KV.of("key", "string"), window);
+
+    Mockito.verify(stateBackendLock).lock();
+    ByteBuffer key = ByteBuffer.wrap(CoderUtils.encodeToByteArray(keyCoder, "key"));
+    Mockito.verify(keyConsumer).accept(key);
+    Mockito.verify(keyedStateBackend).setCurrentKey(key);
+    assertThat(
+        inMemoryTimerInternals.getNextTimer(TimeDomain.EVENT_TIME),
+        is(window.maxTimestamp().plus(1)));
+    Mockito.verify(stateBackendLock).unlock();
+  }
+
+  @Test
+  public void testEnsureStateCleanupWithKeyedInputStateCleaner() throws Exception {
+    GlobalWindow.Coder windowCoder = GlobalWindow.Coder.INSTANCE;
+    InMemoryStateInternals<String> stateInternals = InMemoryStateInternals.forKey("key");
+    List<String> userStateNames = ImmutableList.of("state1", "state2");
+    ImmutableList.Builder<BagState<String>> bagStateBuilder = ImmutableList.builder();
+    for (String userStateName : userStateNames) {
+      BagState<String> state =
+          stateInternals.state(
+              StateNamespaces.window(windowCoder, GlobalWindow.INSTANCE),
+              StateTags.bag(userStateName, StringUtf8Coder.of()));
+      bagStateBuilder.add(state);
+      state.add("this should be cleaned");
+    }
+    ImmutableList<BagState<String>> bagStates = bagStateBuilder.build();
+
+    // Every user state holds one element before clearForWindow and none afterwards
+    ExecutableStageDoFnOperator.StateCleaner stateCleaner =
+        new ExecutableStageDoFnOperator.StateCleaner(userStateNames, windowCoder, stateInternals);
+    for (BagState<String> bagState : bagStates) {
+      assertThat(Iterables.size(bagState.read()), is(1));
+    }
+
+    stateCleaner.clearForWindow(GlobalWindow.INSTANCE);
+
+    for (BagState<String> bagState : bagStates) {
+      assertThat(Iterables.size(bagState.read()), is(0));
+    }
+  }
+
+  @Test
   public void testSerialization() {
     WindowedValue.ValueOnlyWindowedValueCoder<Integer> coder =
         WindowedValue.getValueOnlyCoder(VarIntCoder.of());
@@ -405,6 +519,14 @@ public class ExecutableStageDoFnOperatorTest {
       TupleTag<Integer> mainOutput,
       List<TupleTag<?>> additionalOutputs,
       DoFnOperator.MultiOutputOutputManagerFactory<Integer> outputManagerFactory) {
+    return getOperator(mainOutput, additionalOutputs, outputManagerFactory, null);
+  }
+
+  private ExecutableStageDoFnOperator<Integer, Integer> getOperator(
+      TupleTag<Integer> mainOutput,
+      List<TupleTag<?>> additionalOutputs,
+      DoFnOperator.MultiOutputOutputManagerFactory<Integer> outputManagerFactory,
+      @Nullable Coder keyCoder) {
 
     FlinkExecutableStageContext.Factory contextFactory =
         Mockito.mock(FlinkExecutableStageContext.Factory.class);
@@ -428,7 +550,7 @@ public class ExecutableStageDoFnOperatorTest {
             contextFactory,
             createOutputMap(mainOutput, additionalOutputs),
             WindowingStrategy.globalDefault(),
-            null,
+            keyCoder,
             null);
 
     Whitebox.setInternalState(operator, "stateRequestHandler", stateRequestHandler);