You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by re...@apache.org on 2020/04/27 16:31:44 UTC

[beam] branch master updated: Merge pull request #11154: [BEAM-1819] Key should be available in @OnTimer methods

This is an automated email from the ASF dual-hosted git repository.

reuvenlax pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 591de34  Merge pull request #11154: [BEAM-1819] Key should be available in @OnTimer methods
591de34 is described below

commit 591de3473144de54beef0932131025e2a4d8504b
Author: Rehman Murad Ali <re...@gmail.com>
AuthorDate: Mon Apr 27 21:31:29 2020 +0500

    Merge pull request #11154: [BEAM-1819] Key should be available in @OnTimer methods
---
 CHANGES.md                                         |   1 +
 .../beam/examples/complete/AutoComplete.java       |   4 +-
 .../translation/operators/ApexParDoOperator.java   |   1 +
 .../construction/SplittableParDoNaiveBounded.java  |   5 +
 .../org/apache/beam/runners/core/DoFnRunner.java   |   3 +-
 .../runners/core/LateDataDroppingDoFnRunner.java   |   5 +-
 .../apache/beam/runners/core/ProcessFnRunner.java  |   3 +-
 .../runners/core/PushbackSideInputDoFnRunner.java  |   3 +-
 .../apache/beam/runners/core/SimpleDoFnRunner.java |  24 +++-
 .../core/SimplePushbackSideInputDoFnRunner.java    |   5 +-
 .../beam/runners/core/StatefulDoFnRunner.java      |   6 +-
 .../beam/runners/core/SimpleDoFnRunnerTest.java    |   2 +
 .../SimplePushbackSideInputDoFnRunnerTest.java     |  12 +-
 .../beam/runners/core/StatefulDoFnRunnerTest.java  |   1 +
 ...LifecycleManagerRemovingTransformEvaluator.java |   4 +-
 .../apache/beam/runners/direct/ParDoEvaluator.java |   3 +-
 .../direct/StatefulParDoEvaluatorFactory.java      |   3 +-
 ...cycleManagerRemovingTransformEvaluatorTest.java |   3 +-
 .../flink/metrics/DoFnRunnerWithMetricsUpdate.java |   5 +-
 .../functions/FlinkStatefulDoFnFunction.java       |  14 ++-
 .../wrappers/streaming/DoFnOperator.java           |   3 +-
 .../streaming/ExecutableStageDoFnOperator.java     |   5 +-
 .../streaming/stableinput/BufferedElements.java    |  16 ++-
 .../streaming/stableinput/BufferingDoFnRunner.java |  10 +-
 .../wrappers/streaming/DoFnOperatorTest.java       | 124 +++++++++++++------
 .../stableinput/BufferedElementsTest.java          |   6 +-
 .../dataflow/worker/DataflowProcessFnRunner.java   |   3 +-
 .../dataflow/worker/GroupAlsoByWindowFnRunner.java |   3 +-
 .../runners/dataflow/worker/SimpleParDoFn.java     |   1 +
 .../dataflow/worker/StreamingDataflowWorker.java   |   5 +
 .../StreamingKeyedWorkItemSideInputDoFnRunner.java |   3 +-
 .../worker/StreamingSideInputDoFnRunner.java       |   3 +-
 .../runners/jet/processors/StatefulParDoP.java     |   1 +
 .../samza/metrics/DoFnRunnerWithMetrics.java       |   5 +-
 .../apache/beam/runners/samza/runtime/DoFnOp.java  |   1 +
 .../runtime/DoFnRunnerWithKeyedInternals.java      |   6 +-
 .../runners/samza/runtime/SamzaDoFnRunners.java    |   3 +-
 .../translation/batch/DoFnRunnerWithMetrics.java   |   5 +-
 .../spark/translation/DoFnRunnerWithMetrics.java   |   5 +-
 .../spark/translation/SparkProcessContext.java     |   1 +
 .../java/org/apache/beam/sdk/testing/UsesKey.java  |  28 +++++
 .../java/org/apache/beam/sdk/transforms/DoFn.java  |   9 ++
 .../org/apache/beam/sdk/transforms/DoFnTester.java |   6 +
 .../reflect/ByteBuddyDoFnInvokerFactory.java       |   8 ++
 .../beam/sdk/transforms/reflect/DoFnInvoker.java   |  16 +++
 .../beam/sdk/transforms/reflect/DoFnSignature.java |  21 ++++
 .../sdk/transforms/reflect/DoFnSignatures.java     |  16 ++-
 .../org/apache/beam/sdk/transforms/ParDoTest.java  | 135 +++++++++++++++++++++
 .../meta/provider/datastore/DataStoreV1Table.java  |   5 +-
 .../apache/beam/fn/harness/FnApiDoFnRunner.java    |  23 +++-
 .../beam/sdk/io/gcp/datastore/V1TestUtil.java      |   2 +-
 51 files changed, 482 insertions(+), 103 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index cb0f4e8..520e387 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -180,6 +180,7 @@ Schema Options, it will be removed in version `2.23.0`. ([BEAM-9704](https://iss
 * Fixed exception when running in IPython notebook (Python) ([BEAM-9277](https://issues.apache.org/jira/browse/BEAM-9277)).
 * Fixed Flink uberjar job termination bug. ([BEAM-9225](https://issues.apache.org/jira/browse/BEAM-9225))
 * Fixed SyntaxError in process worker startup ([BEAM-9503](https://issues.apache.org/jira/browse/BEAM-9503))
+* Key should be available in @OnTimer methods (Java) ([BEAM-1819](https://issues.apache.org/jira/browse/BEAM-1819)).
 
 ## Known Issues
 
diff --git a/examples/java/src/main/java/org/apache/beam/examples/complete/AutoComplete.java b/examples/java/src/main/java/org/apache/beam/examples/complete/AutoComplete.java
index 58ab72b..4fcf282 100644
--- a/examples/java/src/main/java/org/apache/beam/examples/complete/AutoComplete.java
+++ b/examples/java/src/main/java/org/apache/beam/examples/complete/AutoComplete.java
@@ -26,7 +26,6 @@ import com.google.api.services.bigquery.model.TableReference;
 import com.google.api.services.bigquery.model.TableRow;
 import com.google.api.services.bigquery.model.TableSchema;
 import com.google.datastore.v1.Entity;
-import com.google.datastore.v1.Key;
 import com.google.datastore.v1.Value;
 import java.io.IOException;
 import java.util.ArrayList;
@@ -386,7 +385,8 @@ public class AutoComplete {
     @ProcessElement
     public void processElement(ProcessContext c) {
       Entity.Builder entityBuilder = Entity.newBuilder();
-      Key key = makeKey(makeKey(kind, ancestorKey).build(), kind, c.element().getKey()).build();
+      com.google.datastore.v1.Key key =
+          makeKey(makeKey(kind, ancestorKey).build(), kind, c.element().getKey()).build();
 
       entityBuilder.setKey(key);
       List<Value> candidates = new ArrayList<>();
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
index 111d997..7e55b67 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
@@ -385,6 +385,7 @@ public class ApexParDoOperator<InputT, OutputT> extends BaseOperator
       pushbackDoFnRunner.onTimer(
           timerData.getTimerId(),
           timerData.getTimerFamilyId(),
+          null,
           window,
           timerData.getTimestamp(),
           timerData.getOutputTimestamp(),
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java
index 31ce3b6..95ac33e 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java
@@ -421,6 +421,11 @@ public class SplittableParDoNaiveBounded {
       }
 
       @Override
+      public Object key() {
+        throw new UnsupportedOperationException();
+      }
+
+      @Override
       public Object sideInput(String tagId) {
         throw new UnsupportedOperationException();
       }
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunner.java
index e2fc262..7ee73c2 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunner.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunner.java
@@ -38,9 +38,10 @@ public interface DoFnRunner<InputT, OutputT> {
    * Calls a {@link DoFn DoFn's} {@link DoFn.OnTimer @OnTimer} method for the given timer in the
    * given window.
    */
-  void onTimer(
+  <KeyT> void onTimer(
       String timerId,
       String timerFamilyId,
+      KeyT key,
       BoundedWindow window,
       Instant timestamp,
       Instant outputTimestamp,
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/LateDataDroppingDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/LateDataDroppingDoFnRunner.java
index 8f19b5f..8b87e07 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/LateDataDroppingDoFnRunner.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/LateDataDroppingDoFnRunner.java
@@ -81,14 +81,15 @@ public class LateDataDroppingDoFnRunner<K, InputT, OutputT, W extends BoundedWin
   }
 
   @Override
-  public void onTimer(
+  public <KeyT> void onTimer(
       String timerId,
       String timerFamilyId,
+      KeyT key,
       BoundedWindow window,
       Instant timestamp,
       Instant outputTimestamp,
       TimeDomain timeDomain) {
-    doFnRunner.onTimer(timerId, timerFamilyId, window, timestamp, outputTimestamp, timeDomain);
+    doFnRunner.onTimer(timerId, timerFamilyId, key, window, timestamp, outputTimestamp, timeDomain);
   }
 
   @Override
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/ProcessFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/ProcessFnRunner.java
index c310c49..707eda7 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/ProcessFnRunner.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/ProcessFnRunner.java
@@ -82,9 +82,10 @@ public class ProcessFnRunner<InputT, OutputT, RestrictionT>
   }
 
   @Override
-  public void onTimer(
+  public <KeyT> void onTimer(
       String timerId,
       String timerFamilyId,
+      KeyT key,
       BoundedWindow window,
       Instant timestamp,
       Instant outputTimestamp,
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/PushbackSideInputDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/PushbackSideInputDoFnRunner.java
index 32a61af..1fd9eb2 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/PushbackSideInputDoFnRunner.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/PushbackSideInputDoFnRunner.java
@@ -43,9 +43,10 @@ public interface PushbackSideInputDoFnRunner<InputT, OutputT> {
   Iterable<WindowedValue<InputT>> processElementInReadyWindows(WindowedValue<InputT> elem);
 
   /** Calls the underlying {@link DoFn.OnTimer} method. */
-  void onTimer(
+  <KeyT> void onTimer(
       String timerId,
       String timerFamilyId,
+      KeyT key,
       BoundedWindow window,
       Instant timestamp,
       Instant outputTimestamp,
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
index b8985dc..80ed59c 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
@@ -188,9 +188,10 @@ public class SimpleDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Out
   }
 
   @Override
-  public void onTimer(
+  public <KeyT> void onTimer(
       String timerId,
       String timerFamilyId,
+      KeyT key,
       BoundedWindow window,
       Instant timestamp,
       Instant outputTimestamp,
@@ -214,8 +215,9 @@ public class SimpleDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Out
         throw new IllegalArgumentException(String.format("Unknown time domain: %s", timeDomain));
     }
 
-    OnTimerArgumentProvider argumentProvider =
-        new OnTimerArgumentProvider(timerId, window, timestamp, effectiveTimestamp, timeDomain);
+    OnTimerArgumentProvider<KeyT> argumentProvider =
+        new OnTimerArgumentProvider<>(
+            timerId, key, window, timestamp, effectiveTimestamp, timeDomain);
     invoker.invokeOnTimer(timerId, timerFamilyId, argumentProvider);
   }
 
@@ -477,6 +479,12 @@ public class SimpleDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Out
     }
 
     @Override
+    public Object key() {
+      throw new UnsupportedOperationException(
+          "Cannot access key as parameter outside of @OnTimer method.");
+    }
+
+    @Override
     public Object sideInput(String tagId) {
       return sideInput(sideInputMapping.get(tagId));
     }
@@ -605,13 +613,14 @@ public class SimpleDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Out
    * A concrete implementation of {@link DoFnInvoker.ArgumentProvider} used for running a {@link
    * DoFn} on a timer.
    */
-  private class OnTimerArgumentProvider extends DoFn<InputT, OutputT>.OnTimerContext
+  private class OnTimerArgumentProvider<KeyT> extends DoFn<InputT, OutputT>.OnTimerContext
       implements DoFnInvoker.ArgumentProvider<InputT, OutputT> {
     private final BoundedWindow window;
     private final Instant fireTimestamp;
     private final Instant timestamp;
     private final TimeDomain timeDomain;
     private final String timerId;
+    private final KeyT key;
 
     /** Lazily initialized; should only be accessed via {@link #getNamespace()}. */
     private @Nullable StateNamespace namespace;
@@ -632,6 +641,7 @@ public class SimpleDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Out
 
     private OnTimerArgumentProvider(
         String timerId,
+        KeyT key,
         BoundedWindow window,
         Instant fireTimestamp,
         Instant timestamp,
@@ -642,6 +652,7 @@ public class SimpleDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Out
       this.fireTimestamp = fireTimestamp;
       this.timestamp = timestamp;
       this.timeDomain = timeDomain;
+      this.key = key;
     }
 
     @Override
@@ -687,6 +698,11 @@ public class SimpleDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Out
     }
 
     @Override
+    public KeyT key() {
+      return key;
+    }
+
+    @Override
     public DoFn<InputT, OutputT>.ProcessContext processContext(DoFn<InputT, OutputT> doFn) {
       throw new UnsupportedOperationException("ProcessContext parameters are not supported.");
     }
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimplePushbackSideInputDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimplePushbackSideInputDoFnRunner.java
index b27e046..d622ba1 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimplePushbackSideInputDoFnRunner.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimplePushbackSideInputDoFnRunner.java
@@ -107,14 +107,15 @@ public class SimplePushbackSideInputDoFnRunner<InputT, OutputT>
   }
 
   @Override
-  public void onTimer(
+  public <KeyT> void onTimer(
       String timerId,
       String timerFamilyId,
+      KeyT key,
       BoundedWindow window,
       Instant timestamp,
       Instant outputTimestamp,
       TimeDomain timeDomain) {
-    underlying.onTimer(timerId, timerFamilyId, window, timestamp, outputTimestamp, timeDomain);
+    underlying.onTimer(timerId, timerFamilyId, key, window, timestamp, outputTimestamp, timeDomain);
   }
 
   @Override
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/StatefulDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/StatefulDoFnRunner.java
index a339333..baead79 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/StatefulDoFnRunner.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/StatefulDoFnRunner.java
@@ -196,9 +196,10 @@ public class StatefulDoFnRunner<InputT, OutputT, W extends BoundedWindow>
   }
 
   @Override
-  public void onTimer(
+  public <KeyT> void onTimer(
       String timerId,
       String timerFamilyId,
+      KeyT key,
       BoundedWindow window,
       Instant timestamp,
       Instant outputTimestamp,
@@ -224,7 +225,8 @@ public class StatefulDoFnRunner<InputT, OutputT, W extends BoundedWindow>
             window,
             stepContext.timerInternals().currentInputWatermarkTime());
       } else {
-        doFnRunner.onTimer(timerId, timerFamilyId, window, timestamp, outputTimestamp, timeDomain);
+        doFnRunner.onTimer(
+            timerId, timerFamilyId, key, window, timestamp, outputTimestamp, timeDomain);
       }
     }
   }
diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/SimpleDoFnRunnerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/SimpleDoFnRunnerTest.java
index e5696bf..a1c3d72 100644
--- a/runners/core-java/src/test/java/org/apache/beam/runners/core/SimpleDoFnRunnerTest.java
+++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/SimpleDoFnRunnerTest.java
@@ -122,6 +122,7 @@ public class SimpleDoFnRunnerTest {
     runner.onTimer(
         TimerDeclaration.PREFIX + ThrowingDoFn.TIMER_ID,
         "",
+        null,
         GlobalWindow.INSTANCE,
         new Instant(0),
         new Instant(0),
@@ -246,6 +247,7 @@ public class SimpleDoFnRunnerTest {
     runner.onTimer(
         TimerDeclaration.PREFIX + DoFnWithTimers.TIMER_ID,
         "",
+        null,
         GlobalWindow.INSTANCE,
         currentTime.plus(offset),
         currentTime.plus(offset),
diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/SimplePushbackSideInputDoFnRunnerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/SimplePushbackSideInputDoFnRunnerTest.java
index 16ba0b1..f9d4296 100644
--- a/runners/core-java/src/test/java/org/apache/beam/runners/core/SimplePushbackSideInputDoFnRunnerTest.java
+++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/SimplePushbackSideInputDoFnRunnerTest.java
@@ -303,7 +303,13 @@ public class SimplePushbackSideInputDoFnRunnerTest {
     // Mocking is not easily compatible with annotation analysis, so we manually record
     // the method call.
     runner.onTimer(
-        timerId, "", window, new Instant(timestamp), new Instant(timestamp), TimeDomain.EVENT_TIME);
+        timerId,
+        "",
+        null,
+        window,
+        new Instant(timestamp),
+        new Instant(timestamp),
+        TimeDomain.EVENT_TIME);
 
     assertThat(
         underlying.firedTimers,
@@ -344,9 +350,10 @@ public class SimplePushbackSideInputDoFnRunnerTest {
     }
 
     @Override
-    public void onTimer(
+    public <KeyT> void onTimer(
         String timerId,
         String timerFamilyId,
+        KeyT key,
         BoundedWindow window,
         Instant timestamp,
         Instant outputTimestamp,
@@ -493,6 +500,7 @@ public class SimplePushbackSideInputDoFnRunnerTest {
       toTrigger.onTimer(
           timer.getTimerId(),
           timer.getTimerFamilyId(),
+          null,
           window,
           timer.getTimestamp(),
           timer.getOutputTimestamp(),
diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/StatefulDoFnRunnerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/StatefulDoFnRunnerTest.java
index 066b41f..6947a4c 100644
--- a/runners/core-java/src/test/java/org/apache/beam/runners/core/StatefulDoFnRunnerTest.java
+++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/StatefulDoFnRunnerTest.java
@@ -391,6 +391,7 @@ public class StatefulDoFnRunnerTest {
       toTrigger.onTimer(
           timer.getTimerId(),
           timer.getTimerFamilyId(),
+          null,
           window,
           timer.getTimestamp(),
           timer.getOutputTimestamp(),
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java
index 8f67ebe..dbc52d8 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java
@@ -58,9 +58,9 @@ class DoFnLifecycleManagerRemovingTransformEvaluator<InputT> implements Transfor
     }
   }
 
-  public void onTimer(TimerData timer, BoundedWindow window) throws Exception {
+  public <KeyT> void onTimer(TimerData timer, KeyT key, BoundedWindow window) throws Exception {
     try {
-      underlying.onTimer(timer, window);
+      underlying.onTimer(timer, key, window);
     } catch (Exception e) {
       onException(e, "Exception encountered while cleaning up after processing a timer");
       throw e;
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java
index 7986fd1..fbe727f 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java
@@ -244,11 +244,12 @@ class ParDoEvaluator<InputT> implements TransformEvaluator<InputT> {
     }
   }
 
-  public void onTimer(TimerData timer, BoundedWindow window) {
+  public <KeyT> void onTimer(TimerData timer, KeyT key, BoundedWindow window) {
     try {
       fnRunner.onTimer(
           timer.getTimerId(),
           timer.getTimerFamilyId(),
+          key,
           window,
           timer.getTimestamp(),
           timer.getOutputTimestamp(),
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
index 0dd826d..4708276 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
@@ -284,7 +284,8 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo
             timer.getNamespace().getClass().getName());
         WindowNamespace<?> windowNamespace = (WindowNamespace) timer.getNamespace();
         BoundedWindow timerWindow = windowNamespace.getWindow();
-        delegateEvaluator.onTimer(timer, timerWindow);
+
+        delegateEvaluator.onTimer(timer, gbkResult.getValue().key(), timerWindow);
 
         StateTag<WatermarkHoldState> timerWatermarkHoldTag = setTimerTag(timer);
 
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluatorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluatorTest.java
index da93882..3f10123 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluatorTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluatorTest.java
@@ -95,7 +95,7 @@ public class DoFnLifecycleManagerRemovingTransformEvaluatorTest {
     ParDoEvaluator<Object> underlying = mock(ParDoEvaluator.class);
     doThrow(IllegalArgumentException.class)
         .when(underlying)
-        .onTimer(any(TimerData.class), any(BoundedWindow.class));
+        .onTimer(any(TimerData.class), any(Object.class), any(BoundedWindow.class));
 
     DoFn<?, ?> original = lifecycleManager.get();
     assertThat(original, not(nullValue()));
@@ -105,6 +105,7 @@ public class DoFnLifecycleManagerRemovingTransformEvaluatorTest {
     try {
       evaluator.onTimer(
           TimerData.of("foo", StateNamespaces.global(), new Instant(0), TimeDomain.EVENT_TIME),
+          "",
           GlobalWindow.INSTANCE);
     } catch (Exception e) {
       assertThat(lifecycleManager.get(), not(Matchers.theInstance(original)));
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/DoFnRunnerWithMetricsUpdate.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/DoFnRunnerWithMetricsUpdate.java
index ce54d5b..a3fe450 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/DoFnRunnerWithMetricsUpdate.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/DoFnRunnerWithMetricsUpdate.java
@@ -66,16 +66,17 @@ public class DoFnRunnerWithMetricsUpdate<InputT, OutputT> implements DoFnRunner<
   }
 
   @Override
-  public void onTimer(
+  public <KeyT> void onTimer(
       final String timerId,
       final String timerFamilyId,
+      final KeyT key,
       final BoundedWindow window,
       final Instant timestamp,
       final Instant outputTimestamp,
       final TimeDomain timeDomain) {
     try (Closeable ignored =
         MetricsEnvironment.scopedMetricsContainer(container.getMetricsContainer(stepName))) {
-      delegate.onTimer(timerId, timerFamilyId, window, timestamp, outputTimestamp, timeDomain);
+      delegate.onTimer(timerId, timerFamilyId, key, window, timestamp, outputTimestamp, timeDomain);
     } catch (IOException e) {
       throw new RuntimeException(e);
     }
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java
index d7fc2de..1dc13fe 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java
@@ -179,13 +179,13 @@ public class FlinkStatefulDoFnFunction<K, V, OutputT>
     timerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
     timerInternals.advanceSynchronizedProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
 
-    fireEligibleTimers(timerInternals, doFnRunner);
+    fireEligibleTimers(key, timerInternals, doFnRunner);
 
     doFnRunner.finishBundle();
   }
 
   private void fireEligibleTimers(
-      InMemoryTimerInternals timerInternals, DoFnRunner<KV<K, V>, OutputT> runner)
+      final K key, InMemoryTimerInternals timerInternals, DoFnRunner<KV<K, V>, OutputT> runner)
       throws Exception {
 
     while (true) {
@@ -195,15 +195,15 @@ public class FlinkStatefulDoFnFunction<K, V, OutputT>
 
       while ((timer = timerInternals.removeNextEventTimer()) != null) {
         hasFired = true;
-        fireTimer(timer, runner);
+        fireTimer(key, timer, runner);
       }
       while ((timer = timerInternals.removeNextProcessingTimer()) != null) {
         hasFired = true;
-        fireTimer(timer, runner);
+        fireTimer(key, timer, runner);
       }
       while ((timer = timerInternals.removeNextSynchronizedProcessingTimer()) != null) {
         hasFired = true;
-        fireTimer(timer, runner);
+        fireTimer(key, timer, runner);
       }
       if (!hasFired) {
         break;
@@ -211,13 +211,15 @@ public class FlinkStatefulDoFnFunction<K, V, OutputT>
     }
   }
 
-  private void fireTimer(TimerInternals.TimerData timer, DoFnRunner<KV<K, V>, OutputT> doFnRunner) {
+  private void fireTimer(
+      final K key, TimerInternals.TimerData timer, DoFnRunner<KV<K, V>, OutputT> doFnRunner) {
     StateNamespace namespace = timer.getNamespace();
     checkArgument(namespace instanceof StateNamespaces.WindowNamespace);
     BoundedWindow window = ((StateNamespaces.WindowNamespace) namespace).getWindow();
     doFnRunner.onTimer(
         timer.getTimerId(),
         timer.getTimerFamilyId(),
+        key,
         window,
         timer.getTimestamp(),
         timer.getOutputTimestamp(),
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
index ca5642f..55f294c 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
@@ -878,7 +878,7 @@ public class DoFnOperator<InputT, OutputT> extends AbstractStreamOperator<Window
   }
 
   // allow overriding this in ExecutableStageDoFnOperator to set the key context
-  protected void fireTimerInternal(Object key, TimerData timerData) {
+  protected void fireTimerInternal(ByteBuffer key, TimerData timerData) {
     fireTimer(timerData);
   }
 
@@ -897,6 +897,7 @@ public class DoFnOperator<InputT, OutputT> extends AbstractStreamOperator<Window
     pushbackDoFnRunner.onTimer(
         timerData.getTimerId(),
         timerData.getTimerFamilyId(),
+        keyedStateInternals.getKey(),
         window,
         timerData.getTimestamp(),
         timerData.getOutputTimestamp(),
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
index 1bbd1e5..1e73a4b 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
@@ -488,7 +488,7 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
   }
 
   @Override
-  protected void fireTimerInternal(Object key, TimerInternals.TimerData timer) {
+  protected void fireTimerInternal(ByteBuffer key, TimerInternals.TimerData timer) {
     // We have to synchronize to ensure the state backend is not concurrently accessed by the state
     // requests
     try (Locker locker = Locker.locked(stateBackendLock)) {
@@ -730,9 +730,10 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
     }
 
     @Override
-    public void onTimer(
+    public <KeyT> void onTimer(
         String timerId,
         String timerFamilyId,
+        KeyT key,
         BoundedWindow window,
         Instant timestamp,
         Instant outputTimestamp,
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferedElements.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferedElements.java
index 5c5ca6a..883d1cc 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferedElements.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferedElements.java
@@ -64,7 +64,7 @@ class BufferedElements {
     }
   }
 
-  static final class Timer implements BufferedElement {
+  static final class Timer<KeyT> implements BufferedElement {
 
     private final String timerId;
     private final String timerFamilyId;
@@ -72,10 +72,12 @@ class BufferedElements {
     private final Instant timestamp;
     private final Instant outputTimestamp;
     private final TimeDomain timeDomain;
+    private final KeyT key;
 
     Timer(
         String timerId,
         String timerFamilyId,
+        KeyT key,
         BoundedWindow window,
         Instant timestamp,
         Instant outputTimestamp,
@@ -83,6 +85,7 @@ class BufferedElements {
       this.timerId = timerId;
       this.window = window;
       this.timestamp = timestamp;
+      this.key = key;
       this.timeDomain = timeDomain;
       this.outputTimestamp = outputTimestamp;
       this.timerFamilyId = timerFamilyId;
@@ -90,7 +93,8 @@ class BufferedElements {
 
     @Override
     public void processWith(DoFnRunner doFnRunner) {
-      doFnRunner.onTimer(timerId, timerFamilyId, window, timestamp, outputTimestamp, timeDomain);
+      doFnRunner.onTimer(
+          timerId, timerFamilyId, key, window, timestamp, outputTimestamp, timeDomain);
     }
 
     @Override
@@ -123,12 +127,15 @@ class BufferedElements {
 
     private final org.apache.beam.sdk.coders.Coder<WindowedValue> elementCoder;
     private final org.apache.beam.sdk.coders.Coder<BoundedWindow> windowCoder;
+    private final Object key;
 
     public Coder(
         org.apache.beam.sdk.coders.Coder<WindowedValue> elementCoder,
-        org.apache.beam.sdk.coders.Coder<BoundedWindow> windowCoder) {
+        org.apache.beam.sdk.coders.Coder<BoundedWindow> windowCoder,
+        Object key) {
       this.elementCoder = elementCoder;
       this.windowCoder = windowCoder;
+      this.key = key;
     }
 
     @Override
@@ -157,9 +164,10 @@ class BufferedElements {
         case ELEMENT_MAGIC_BYTE:
           return new Element(elementCoder.decode(inStream));
         case TIMER_MAGIC_BYTE:
-          return new Timer(
+          return new Timer<>(
               STRING_CODER.decode(inStream),
               STRING_CODER.decode(inStream),
+              key,
               windowCoder.decode(inStream),
               INSTANT_CODER.decode(inStream),
               INSTANT_CODER.decode(inStream),
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferingDoFnRunner.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferingDoFnRunner.java
index b0f9905..2499f47 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferingDoFnRunner.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferingDoFnRunner.java
@@ -102,7 +102,8 @@ public class BufferingDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT,
           ListStateDescriptor<BufferedElement> stateDescriptor =
               new ListStateDescriptor<>(
                   stateName + stateId,
-                  new CoderTypeSerializer<>(new BufferedElements.Coder(inputCoder, windowCoder)));
+                  new CoderTypeSerializer<>(
+                      new BufferedElements.Coder(inputCoder, windowCoder, null)));
           if (keyedStateBackend != null) {
             return KeyedBufferingElementsHandler.create(keyedStateBackend, stateDescriptor);
           } else {
@@ -147,16 +148,17 @@ public class BufferingDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT,
   }
 
   @Override
-  public void onTimer(
+  public <KeyT> void onTimer(
       String timerId,
       String timerFamilyId,
+      KeyT key,
       BoundedWindow window,
       Instant timestamp,
       Instant outputTimestamp,
       TimeDomain timeDomain) {
     currentBufferingElementsHandler.buffer(
-        new BufferedElements.Timer(
-            timerId, timerFamilyId, window, timestamp, outputTimestamp, timeDomain));
+        new BufferedElements.Timer<>(
+            timerId, timerFamilyId, key, window, timestamp, outputTimestamp, timeDomain));
   }
 
   @Override
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java
index c569045..5469f6fd 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java
@@ -30,6 +30,7 @@ import static org.junit.Assert.assertThrows;
 
 import com.fasterxml.jackson.databind.type.TypeFactory;
 import com.fasterxml.jackson.databind.util.LRUMap;
+import java.nio.ByteBuffer;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -79,7 +80,6 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.FluentIt
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
@@ -313,13 +313,17 @@ public class DoFnOperatorTest {
           }
         };
 
+    VarIntCoder keyCoder = VarIntCoder.of();
     WindowedValue.FullWindowedValueCoder<Integer> inputCoder =
-        WindowedValue.getFullCoder(VarIntCoder.of(), windowingStrategy.getWindowFn().windowCoder());
+        WindowedValue.getFullCoder(keyCoder, windowingStrategy.getWindowFn().windowCoder());
 
     WindowedValue.FullWindowedValueCoder<String> outputCoder =
         WindowedValue.getFullCoder(
             StringUtf8Coder.of(), windowingStrategy.getWindowFn().windowCoder());
 
+    KeySelector<WindowedValue<Integer>, ByteBuffer> keySelector =
+        e -> FlinkKeyUtils.encodeKey(e.getValue(), keyCoder);
+
     TupleTag<String> outputTag = new TupleTag<>("main-output");
 
     DoFnOperator<Integer, String> doFnOperator =
@@ -335,14 +339,16 @@ public class DoFnOperatorTest {
             new HashMap<>(), /* side-input mapping */
             Collections.emptyList(), /* side inputs */
             PipelineOptionsFactory.as(FlinkPipelineOptions.class),
-            VarIntCoder.of(), /* key coder */
-            WindowedValue::getValue,
+            keyCoder, /* key coder */
+            keySelector,
             DoFnSchemaInformation.create(),
             Collections.emptyMap());
 
     OneInputStreamOperatorTestHarness<WindowedValue<Integer>, WindowedValue<String>> testHarness =
         new KeyedOneInputStreamOperatorTestHarness<>(
-            doFnOperator, WindowedValue::getValue, new CoderTypeInformation<>(VarIntCoder.of()));
+            doFnOperator,
+            keySelector,
+            new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of()));
 
     testHarness.setup(new CoderTypeSerializer<>(outputCoder));
 
@@ -417,12 +423,16 @@ public class DoFnOperatorTest {
           }
         };
 
+    VarIntCoder keyCoder = VarIntCoder.of();
     Coder<WindowedValue<Integer>> inputCoder =
-        WindowedValue.getFullCoder(VarIntCoder.of(), windowingStrategy.getWindowFn().windowCoder());
+        WindowedValue.getFullCoder(keyCoder, windowingStrategy.getWindowFn().windowCoder());
     Coder<WindowedValue<String>> outputCoder =
         WindowedValue.getFullCoder(
             StringUtf8Coder.of(), windowingStrategy.getWindowFn().windowCoder());
 
+    KeySelector<WindowedValue<Integer>, ByteBuffer> keySelector =
+        e -> FlinkKeyUtils.encodeKey(e.getValue(), keyCoder);
+
     TupleTag<String> outputTag = new TupleTag<>("main-output");
 
     DoFnOperator<Integer, String> doFnOperator =
@@ -438,14 +448,16 @@ public class DoFnOperatorTest {
             new HashMap<>(), /* side-input mapping */
             Collections.emptyList(), /* side inputs */
             PipelineOptionsFactory.as(FlinkPipelineOptions.class),
-            VarIntCoder.of(), /* key coder */
-            WindowedValue::getValue,
+            keyCoder, /* key coder */
+            keySelector,
             DoFnSchemaInformation.create(),
             Collections.emptyMap());
 
     OneInputStreamOperatorTestHarness<WindowedValue<Integer>, WindowedValue<String>> testHarness =
         new KeyedOneInputStreamOperatorTestHarness<>(
-            doFnOperator, WindowedValue::getValue, new CoderTypeInformation<>(VarIntCoder.of()));
+            doFnOperator,
+            keySelector,
+            new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of()));
 
     testHarness.open();
 
@@ -533,6 +545,9 @@ public class DoFnOperatorTest {
 
     TupleTag<KV<String, Integer>> outputTag = new TupleTag<>("main-output");
 
+    KeySelector<WindowedValue<KV<String, Integer>>, ByteBuffer> keySelector =
+        e -> FlinkKeyUtils.encodeKey(e.getValue().getKey(), StringUtf8Coder.of());
+
     DoFnOperator<KV<String, Integer>, KV<String, Integer>> doFnOperator =
         new DoFnOperator<>(
             fn,
@@ -547,17 +562,17 @@ public class DoFnOperatorTest {
             Collections.emptyList(), /* side inputs */
             PipelineOptionsFactory.as(FlinkPipelineOptions.class),
             StringUtf8Coder.of(), /* key coder */
-            kvWindowedValue -> kvWindowedValue.getValue().getKey(),
+            keySelector,
             DoFnSchemaInformation.create(),
             Collections.emptyMap());
 
     KeyedOneInputStreamOperatorTestHarness<
-            String, WindowedValue<KV<String, Integer>>, WindowedValue<KV<String, Integer>>>
+            ByteBuffer, WindowedValue<KV<String, Integer>>, WindowedValue<KV<String, Integer>>>
         testHarness =
             new KeyedOneInputStreamOperatorTestHarness<>(
                 doFnOperator,
-                kvWindowedValue -> kvWindowedValue.getValue().getKey(),
-                new CoderTypeInformation<>(StringUtf8Coder.of()));
+                keySelector,
+                new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of()));
 
     testHarness.open();
 
@@ -633,9 +648,11 @@ public class DoFnOperatorTest {
     ImmutableMap<Integer, PCollectionView<?>> sideInputMapping =
         ImmutableMap.<Integer, PCollectionView<?>>builder().put(1, view1).put(2, view2).build();
 
-    Coder<String> keyCoder = null;
+    Coder<String> keyCoder = StringUtf8Coder.of();
+    ;
+    KeySelector<WindowedValue<String>, ByteBuffer> keySelector = null;
     if (keyed) {
-      keyCoder = StringUtf8Coder.of();
+      keySelector = value -> FlinkKeyUtils.encodeKey(value.getValue(), keyCoder);
     }
 
     DoFnOperator<String, String> doFnOperator =
@@ -651,8 +668,8 @@ public class DoFnOperatorTest {
             sideInputMapping, /* side-input mapping */
             ImmutableList.of(view1, view2), /* side inputs */
             PipelineOptionsFactory.as(FlinkPipelineOptions.class),
-            keyCoder,
-            keyed ? WindowedValue::getValue : null,
+            keyed ? keyCoder : null,
+            keyed ? keySelector : null,
             DoFnSchemaInformation.create(),
             Collections.emptyMap());
 
@@ -663,7 +680,10 @@ public class DoFnOperatorTest {
       // we use a dummy key for the second input since it is considered to be broadcast
       testHarness =
           new KeyedTwoInputStreamOperatorTestHarness<>(
-              doFnOperator, WindowedValue::getValue, null, BasicTypeInfo.STRING_TYPE_INFO);
+              doFnOperator,
+              keySelector,
+              null,
+              new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of()));
     }
 
     testHarness.open();
@@ -735,13 +755,15 @@ public class DoFnOperatorTest {
 
     TupleTag<KV<String, Long>> outputTag = new TupleTag<>("main-output");
 
+    StringUtf8Coder keyCoder = StringUtf8Coder.of();
+    KvToByteBufferKeySelector keySelector = new KvToByteBufferKeySelector<>(keyCoder);
+    KvCoder<String, Long> coder = KvCoder.of(keyCoder, VarLongCoder.of());
+
     FullWindowedValueCoder<KV<String, Long>> kvCoder =
-        WindowedValue.getFullCoder(
-            KvCoder.of(StringUtf8Coder.of(), VarLongCoder.of()),
-            windowingStrategy.getWindowFn().windowCoder());
+        WindowedValue.getFullCoder(coder, windowingStrategy.getWindowFn().windowCoder());
 
-    CoderTypeInformation<String> keyCoderInfo = new CoderTypeInformation<>(StringUtf8Coder.of());
-    KeySelector<WindowedValue<KV<String, Long>>, String> keySelector = e -> e.getValue().getKey();
+    CoderTypeInformation<ByteBuffer> keyCoderInfo =
+        new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of());
 
     OneInputStreamOperatorTestHarness<
             WindowedValue<KV<String, Long>>, WindowedValue<KV<String, Long>>>
@@ -751,6 +773,7 @@ public class DoFnOperatorTest {
                 filterElementsEqualToCountFn,
                 kvCoder,
                 kvCoder,
+                keyCoder,
                 outputTag,
                 keyCoderInfo,
                 keySelector);
@@ -770,6 +793,7 @@ public class DoFnOperatorTest {
             filterElementsEqualToCountFn,
             kvCoder,
             kvCoder,
+            keyCoder,
             outputTag,
             keyCoderInfo,
             keySelector);
@@ -835,11 +859,13 @@ public class DoFnOperatorTest {
   public void keyedParDoSideInputCheckpointing() throws Exception {
     sideInputCheckpointing(
         () -> {
+          StringUtf8Coder keyCoder = StringUtf8Coder.of();
           Coder<WindowedValue<String>> coder =
-              WindowedValue.getFullCoder(StringUtf8Coder.of(), IntervalWindow.getCoder());
+              WindowedValue.getFullCoder(keyCoder, IntervalWindow.getCoder());
           TupleTag<String> outputTag = new TupleTag<>("main-output");
 
-          StringUtf8Coder keyCoder = StringUtf8Coder.of();
+          KeySelector<WindowedValue<String>, ByteBuffer> keySelector =
+              e -> FlinkKeyUtils.encodeKey(e.getValue(), keyCoder);
 
           ImmutableMap<Integer, PCollectionView<?>> sideInputMapping =
               ImmutableMap.<Integer, PCollectionView<?>>builder()
@@ -861,16 +887,16 @@ public class DoFnOperatorTest {
                   ImmutableList.of(view1, view2), /* side inputs */
                   PipelineOptionsFactory.as(FlinkPipelineOptions.class),
                   keyCoder,
-                  WindowedValue::getValue,
+                  keySelector,
                   DoFnSchemaInformation.create(),
                   Collections.emptyMap());
 
           return new KeyedTwoInputStreamOperatorTestHarness<>(
               doFnOperator,
-              WindowedValue::getValue,
+              keySelector,
               // we use a dummy key for the second input since it is considered to be broadcast
               null,
-              BasicTypeInfo.STRING_TYPE_INFO);
+              new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of()));
         });
   }
 
@@ -970,12 +996,14 @@ public class DoFnOperatorTest {
   public void keyedParDoPushbackDataCheckpointing() throws Exception {
     pushbackDataCheckpointing(
         () -> {
+          StringUtf8Coder keyCoder = StringUtf8Coder.of();
           Coder<WindowedValue<String>> coder =
-              WindowedValue.getFullCoder(StringUtf8Coder.of(), IntervalWindow.getCoder());
+              WindowedValue.getFullCoder(keyCoder, IntervalWindow.getCoder());
 
           TupleTag<String> outputTag = new TupleTag<>("main-output");
 
-          StringUtf8Coder keyCoder = StringUtf8Coder.of();
+          KeySelector<WindowedValue<String>, ByteBuffer> keySelector =
+              e -> FlinkKeyUtils.encodeKey(e.getValue(), keyCoder);
 
           ImmutableMap<Integer, PCollectionView<?>> sideInputMapping =
               ImmutableMap.<Integer, PCollectionView<?>>builder()
@@ -997,16 +1025,16 @@ public class DoFnOperatorTest {
                   ImmutableList.of(view1, view2), /* side inputs */
                   PipelineOptionsFactory.as(FlinkPipelineOptions.class),
                   keyCoder,
-                  WindowedValue::getValue,
+                  keySelector,
                   DoFnSchemaInformation.create(),
                   Collections.emptyMap());
 
           return new KeyedTwoInputStreamOperatorTestHarness<>(
               doFnOperator,
-              WindowedValue::getValue,
+              keySelector,
               // we use a dummy key for the second input since it is considered to be broadcast
               null,
-              BasicTypeInfo.STRING_TYPE_INFO);
+              new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of()));
         });
   }
 
@@ -1093,8 +1121,9 @@ public class DoFnOperatorTest {
           }
         };
 
+    VarIntCoder keyCoder = VarIntCoder.of();
     WindowedValue.FullWindowedValueCoder<Integer> inputCoder =
-        WindowedValue.getFullCoder(VarIntCoder.of(), windowingStrategy.getWindowFn().windowCoder());
+        WindowedValue.getFullCoder(keyCoder, windowingStrategy.getWindowFn().windowCoder());
 
     WindowedValue.FullWindowedValueCoder<String> outputCoder =
         WindowedValue.getFullCoder(
@@ -1104,12 +1133,21 @@ public class DoFnOperatorTest {
 
     final CoderTypeSerializer<WindowedValue<String>> outputSerializer =
         new CoderTypeSerializer<>(outputCoder);
-    CoderTypeInformation<Integer> keyCoderInfo = new CoderTypeInformation<>(VarIntCoder.of());
-    KeySelector<WindowedValue<Integer>, Integer> keySelector = WindowedValue::getValue;
+    CoderTypeInformation<ByteBuffer> keyCoderInfo =
+        new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of());
+    KeySelector<WindowedValue<Integer>, ByteBuffer> keySelector =
+        e -> FlinkKeyUtils.encodeKey(e.getValue(), keyCoder);
 
     OneInputStreamOperatorTestHarness<WindowedValue<Integer>, WindowedValue<String>> testHarness =
         createTestHarness(
-            windowingStrategy, fn, inputCoder, outputCoder, outputTag, keyCoderInfo, keySelector);
+            windowingStrategy,
+            fn,
+            inputCoder,
+            outputCoder,
+            keyCoder,
+            outputTag,
+            keyCoderInfo,
+            keySelector);
 
     testHarness.setup(outputSerializer);
 
@@ -1131,7 +1169,14 @@ public class DoFnOperatorTest {
 
     testHarness =
         createTestHarness(
-            windowingStrategy, fn, inputCoder, outputCoder, outputTag, keyCoderInfo, keySelector);
+            windowingStrategy,
+            fn,
+            inputCoder,
+            outputCoder,
+            VarIntCoder.of(),
+            outputTag,
+            keyCoderInfo,
+            keySelector);
     testHarness.setup(outputSerializer);
     testHarness.initializeState(snapshot);
     testHarness.open();
@@ -1154,6 +1199,7 @@ public class DoFnOperatorTest {
           DoFn<InT, OutT> fn,
           FullWindowedValueCoder<InT> inputCoder,
           FullWindowedValueCoder<OutT> outputCoder,
+          Coder<?> keyCoder,
           TupleTag<OutT> outputTag,
           TypeInformation<K> keyCoderInfo,
           KeySelector<WindowedValue<InT>, K> keySelector)
@@ -1171,7 +1217,7 @@ public class DoFnOperatorTest {
             new HashMap<>(), /* side-input mapping */
             Collections.emptyList(), /* side inputs */
             PipelineOptionsFactory.as(FlinkPipelineOptions.class),
-            VarIntCoder.of() /* key coder */,
+            keyCoder /* key coder */,
             keySelector,
             DoFnSchemaInformation.create(),
             Collections.emptyMap());
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferedElementsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferedElementsTest.java
index 0828a22..0868088 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferedElementsTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/stableinput/BufferedElementsTest.java
@@ -28,6 +28,7 @@ import org.apache.beam.sdk.state.TimeDomain;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.hamcrest.Matchers;
 import org.joda.time.Instant;
@@ -44,8 +45,8 @@ public class BufferedElementsTest {
     org.apache.beam.sdk.coders.Coder windowCoder = GlobalWindow.Coder.INSTANCE;
     WindowedValue.WindowedValueCoder windowedValueCoder =
         WindowedValue.FullWindowedValueCoder.of(elementCoder, windowCoder);
-
-    BufferedElements.Coder coder = new BufferedElements.Coder(windowedValueCoder, windowCoder);
+    KV<String, Integer> key = KV.of("one", 1);
+    BufferedElements.Coder coder = new BufferedElements.Coder(windowedValueCoder, windowCoder, key);
 
     BufferedElement element =
         new BufferedElements.Element(
@@ -54,6 +55,7 @@ public class BufferedElementsTest {
         new BufferedElements.Timer(
             "timerId",
             "timerId",
+            key,
             GlobalWindow.INSTANCE,
             new Instant(1),
             new Instant(1),
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowProcessFnRunner.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowProcessFnRunner.java
index 3cb92c9..62185f5 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowProcessFnRunner.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowProcessFnRunner.java
@@ -108,9 +108,10 @@ class DataflowProcessFnRunner<InputT, OutputT, RestrictionT>
   }
 
   @Override
-  public void onTimer(
+  public <KeyT> void onTimer(
       String timerId,
       String timerFamilyId,
+      KeyT key,
       BoundedWindow window,
       Instant timestamp,
       Instant outputTimestamp,
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/GroupAlsoByWindowFnRunner.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/GroupAlsoByWindowFnRunner.java
index 6296461..c8d8567 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/GroupAlsoByWindowFnRunner.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/GroupAlsoByWindowFnRunner.java
@@ -81,9 +81,10 @@ public class GroupAlsoByWindowFnRunner<InputT, OutputT> implements DoFnRunner<In
   }
 
   @Override
-  public void onTimer(
+  public <KeyT> void onTimer(
       String timerId,
       String timerFamilyId,
+      KeyT key,
       BoundedWindow window,
       Instant timestamp,
       Instant outputTimestamp,
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java
index 76ea9d0..2985a6f 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java
@@ -361,6 +361,7 @@ public class SimpleParDoFn<InputT, OutputT> implements ParDoFn {
       fnRunner.onTimer(
           timer.getTimerId(),
           timer.getTimerFamilyId(),
+          this.stepContext.stateInternals().getKey(),
           window,
           timer.getTimestamp(),
           timer.getOutputTimestamp(),
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index 114eec3..083914b 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -118,6 +118,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.Commi
 import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.GetWorkStream;
 import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.StreamPool;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.extensions.gcp.util.Transport;
 import org.apache.beam.sdk.fn.IdGenerator;
 import org.apache.beam.sdk.fn.IdGenerators;
@@ -1160,6 +1161,10 @@ public class StreamingDataflowWorker {
     // Note that TimerOrElementCoder is a backwards-compatibility class
     // that is really a FakeKeyedWorkItemCoder
     Coder<?> valueCoder = ((WindowedValueCoder<?>) readCoder).getValueCoder();
+
+    if (valueCoder instanceof KvCoder<?, ?>) {
+      return ((KvCoder<?, ?>) valueCoder).getKeyCoder();
+    }
     if (!(valueCoder instanceof WindmillKeyedWorkItem.FakeKeyedWorkItemCoder<?, ?>)) {
       return null;
     }
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingKeyedWorkItemSideInputDoFnRunner.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingKeyedWorkItemSideInputDoFnRunner.java
index 41da4b3..be97130 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingKeyedWorkItemSideInputDoFnRunner.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingKeyedWorkItemSideInputDoFnRunner.java
@@ -132,9 +132,10 @@ public class StreamingKeyedWorkItemSideInputDoFnRunner<K, InputT, OutputT, W ext
   }
 
   @Override
-  public void onTimer(
+  public <KeyT> void onTimer(
       String timerId,
       String timerFamilyId,
+      KeyT key,
       BoundedWindow window,
       Instant timestamp,
       Instant outputTimestamp,
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunner.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunner.java
index f7e86c7..f4c0c77 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunner.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunner.java
@@ -76,9 +76,10 @@ public class StreamingSideInputDoFnRunner<InputT, OutputT, W extends BoundedWind
   }
 
   @Override
-  public void onTimer(
+  public <KeyT> void onTimer(
       String timerId,
       String timerFamilyId,
+      KeyT key,
       BoundedWindow window,
       Instant timestamp,
       Instant outputTimestamp,
diff --git a/runners/jet/src/main/java/org/apache/beam/runners/jet/processors/StatefulParDoP.java b/runners/jet/src/main/java/org/apache/beam/runners/jet/processors/StatefulParDoP.java
index 76e8375..1c51e9f 100644
--- a/runners/jet/src/main/java/org/apache/beam/runners/jet/processors/StatefulParDoP.java
+++ b/runners/jet/src/main/java/org/apache/beam/runners/jet/processors/StatefulParDoP.java
@@ -95,6 +95,7 @@ public class StatefulParDoP<OutputT>
     doFnRunner.onTimer(
         timer.getTimerId(),
         timer.getTimerFamilyId(),
+        null,
         window,
         timer.getTimestamp(),
         timer.getOutputTimestamp(),
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/metrics/DoFnRunnerWithMetrics.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/metrics/DoFnRunnerWithMetrics.java
index 8a0f579..b3cedbf 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/metrics/DoFnRunnerWithMetrics.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/metrics/DoFnRunnerWithMetrics.java
@@ -56,9 +56,10 @@ public class DoFnRunnerWithMetrics<InT, OutT> implements DoFnRunner<InT, OutT> {
   }
 
   @Override
-  public void onTimer(
+  public <KeyT> void onTimer(
       String timerId,
       String timerFamilyId,
+      KeyT key,
       BoundedWindow window,
       Instant timestamp,
       Instant outputTimestamp,
@@ -66,7 +67,7 @@ public class DoFnRunnerWithMetrics<InT, OutT> implements DoFnRunner<InT, OutT> {
     withMetrics(
         () ->
             underlying.onTimer(
-                timerId, timerFamilyId, window, timestamp, outputTimestamp, timeDomain),
+                timerId, timerFamilyId, key, window, timestamp, outputTimestamp, timeDomain),
         false);
   }
 
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java
index cd140c7..b9d6216 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java
@@ -455,6 +455,7 @@ public class DoFnOp<InT, FnOutT, OutT> implements Op<InT, OutT, Void> {
       pushbackFnRunner.onTimer(
           timer.getTimerId(),
           timer.getTimerFamilyId(),
+          null,
           window,
           timer.getTimestamp(),
           timer.getOutputTimestamp(),
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnRunnerWithKeyedInternals.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnRunnerWithKeyedInternals.java
index 3b2d1cb..d319c6a 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnRunnerWithKeyedInternals.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnRunnerWithKeyedInternals.java
@@ -65,6 +65,7 @@ public class DoFnRunnerWithKeyedInternals<InputT, OutputT> implements DoFnRunner
       onTimer(
           timer.getTimerId(),
           timer.getTimerFamilyId(),
+          keyedTimerData.getKey(),
           window,
           timer.getTimestamp(),
           timer.getOutputTimestamp(),
@@ -75,16 +76,17 @@ public class DoFnRunnerWithKeyedInternals<InputT, OutputT> implements DoFnRunner
   }
 
   @Override
-  public void onTimer(
+  public <KeyT> void onTimer(
       String timerId,
       String timerFamilyId,
+      KeyT key,
       BoundedWindow window,
       Instant timestamp,
       Instant outputTimestamp,
       TimeDomain timeDomain) {
     checkState(keyedInternals.getKey() != null, "Key is not set for timer");
 
-    underlying.onTimer(timerId, timerFamilyId, window, timestamp, outputTimestamp, timeDomain);
+    underlying.onTimer(timerId, timerFamilyId, key, window, timestamp, outputTimestamp, timeDomain);
   }
 
   @Override
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java
index 0d8aa08..f9e1325 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java
@@ -264,9 +264,10 @@ public class SamzaDoFnRunners {
     }
 
     @Override
-    public void onTimer(
+    public <KeyT> void onTimer(
         String timerId,
         String timerFamilyId,
+        KeyT key,
         BoundedWindow window,
         Instant timestamp,
         Instant outputTimestamp,
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnRunnerWithMetrics.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnRunnerWithMetrics.java
index 55d97ba..0294228 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnRunnerWithMetrics.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnRunnerWithMetrics.java
@@ -69,15 +69,16 @@ class DoFnRunnerWithMetrics<InputT, OutputT> implements DoFnRunner<InputT, Outpu
   }
 
   @Override
-  public void onTimer(
+  public <KeyT> void onTimer(
       final String timerId,
       final String timerFamilyId,
+      KeyT key,
       final BoundedWindow window,
       final Instant timestamp,
       final Instant outputTimestamp,
       final TimeDomain timeDomain) {
     try (Closeable ignored = MetricsEnvironment.scopedMetricsContainer(metricsContainer())) {
-      delegate.onTimer(timerId, timerFamilyId, window, timestamp, outputTimestamp, timeDomain);
+      delegate.onTimer(timerId, timerFamilyId, key, window, timestamp, outputTimestamp, timeDomain);
     } catch (IOException e) {
       throw new RuntimeException(e);
     }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java
index 013f860..0784292 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java
@@ -69,15 +69,16 @@ class DoFnRunnerWithMetrics<InputT, OutputT> implements DoFnRunner<InputT, Outpu
   }
 
   @Override
-  public void onTimer(
+  public <KeyT> void onTimer(
       final String timerId,
       final String timerFamilyId,
+      final KeyT key,
       final BoundedWindow window,
       final Instant timestamp,
       final Instant outputTimestamp,
       final TimeDomain timeDomain) {
     try (Closeable ignored = MetricsEnvironment.scopedMetricsContainer(metricsContainer())) {
-      delegate.onTimer(timerId, timerFamilyId, window, timestamp, outputTimestamp, timeDomain);
+      delegate.onTimer(timerId, timerFamilyId, key, window, timestamp, outputTimestamp, timeDomain);
     } catch (IOException e) {
       throw new RuntimeException(e);
     }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
index 9cbbeda..657da55 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
@@ -164,6 +164,7 @@ class SparkProcessContext<FnInputT, FnOutputT, OutputT> {
       doFnRunner.onTimer(
           timer.getTimerId(),
           timer.getTimerFamilyId(),
+          null,
           window,
           timer.getTimestamp(),
           timer.getOutputTimestamp(),
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesKey.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesKey.java
new file mode 100644
index 0000000..c25d25a
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesKey.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.testing;
+
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.transforms.DoFn.OnTimer;
+
+/**
+ * Category tag for validation tests which use the key parameter. Tests tagged with {@link UsesKey}
+ * should be run for runners which support the key parameter in {@link OnTimer} methods.
+ */
+@Internal
+public interface UsesKey {}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
index fda2317..c4df77b 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
@@ -480,6 +480,15 @@ public abstract class DoFn<InputT, OutputT> implements Serializable, HasDisplayD
     String value();
   }
 
+  /**
+   * Parameter annotation for accessing the key of an input element that is a {@link
+   * org.apache.beam.sdk.values.KV} pair.
+   */
+  @Documented
+  @Retention(RetentionPolicy.RUNTIME)
+  @Target(ElementType.PARAMETER)
+  public @interface Key {}
+
   /** Annotation for specifying specific fields that are accessed in a Schema PCollection. */
   @Retention(RetentionPolicy.RUNTIME)
   @Target({ElementType.FIELD, ElementType.PARAMETER})
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
index 4d9f81f..b02a894 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
@@ -259,6 +259,12 @@ public class DoFnTester<InputT, OutputT> implements AutoCloseable {
             }
 
             @Override
+            public Object key() {
+              throw new UnsupportedOperationException(
+                  "Cannot access key as parameter outside of @OnTimer method.");
+            }
+
+            @Override
             public Instant timestamp(DoFn<InputT, OutputT> doFn) {
               return processContext.timestamp();
             }
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
index ebf3d3b..d4d2cfc 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
@@ -132,6 +132,7 @@ class ByteBuddyDoFnInvokerFactory implements DoFnInvokerFactory {
   public static final String SIDE_INPUT_PARAMETER_METHOD = "sideInput";
   public static final String TIMER_FAMILY_PARAMETER_METHOD = "timerFamily";
   public static final String TIMER_ID_PARAMETER_METHOD = "timerId";
+  public static final String KEY_PARAMETER_METHOD = "key";
 
   /**
    * Returns a {@link ByteBuddyDoFnInvokerFactory} shared with all other invocations, so that its
@@ -1054,6 +1055,13 @@ class ByteBuddyDoFnInvokerFactory implements DoFnInvokerFactory {
                     getExtraContextFactoryMethodDescription(
                         TIMER_ID_PARAMETER_METHOD, DoFn.class)));
           }
+
+          @Override
+          public StackManipulation dispatch(DoFnSignature.Parameter.KeyParameter p) {
+            return new StackManipulation.Compound(
+                simpleExtraContextParameter(KEY_PARAMETER_METHOD),
+                TypeCasting.to(new TypeDescription.ForLoadedType(p.keyT().getRawType())));
+          }
         });
   }
 
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
index a03639f..f8d4a36 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
@@ -171,6 +171,11 @@ public interface DoFnInvoker<InputT, OutputT> {
     /** Provide a reference to the input element. */
     InputT element(DoFn<InputT, OutputT> doFn);
 
+    /**
+     * Provide a reference to the key of the input {@link org.apache.beam.sdk.values.KV} pair.
+     */
+    Object key();
+
     /** Provide a reference to the input sideInput with the specified tag. */
     Object sideInput(String tagId);
 
@@ -258,6 +263,12 @@ public interface DoFnInvoker<InputT, OutputT> {
     }
 
     @Override
+    public Object key() {
+      throw new UnsupportedOperationException(
+          "Cannot access key as parameter outside of @OnTimer method.");
+    }
+
+    @Override
     public Object sideInput(String tagId) {
       throw new UnsupportedOperationException(
           String.format("SideInput unsupported in %s", getErrorContext()));
@@ -451,6 +462,11 @@ public interface DoFnInvoker<InputT, OutputT> {
     }
 
     @Override
+    public Object key() {
+      return delegate.key();
+    }
+
+    @Override
     public Object sideInput(String tagId) {
       return delegate.sideInput(tagId);
     }
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
index 8a423a2..56fe771 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
@@ -298,6 +298,8 @@ public abstract class DoFnSignature {
         return cases.dispatch((TimerIdParameter) this);
       } else if (this instanceof BundleFinalizerParameter) {
         return cases.dispatch((BundleFinalizerParameter) this);
+      } else if (this instanceof KeyParameter) {
+        return cases.dispatch((KeyParameter) this);
       } else {
         throw new IllegalStateException(
             String.format(
@@ -354,6 +356,8 @@ public abstract class DoFnSignature {
 
       ResultT dispatch(BundleFinalizerParameter p);
 
+      ResultT dispatch(KeyParameter p);
+
       /** A base class for a visitor with a default method for cases it is not interested in. */
       abstract class WithDefault<ResultT> implements Cases<ResultT> {
 
@@ -473,6 +477,11 @@ public abstract class DoFnSignature {
         public ResultT dispatch(TimerFamilyParameter p) {
           return dispatchDefault(p);
         }
+
+        @Override
+        public ResultT dispatch(KeyParameter p) {
+          return dispatchDefault(p);
+        }
       }
     }
 
@@ -575,6 +584,11 @@ public abstract class DoFnSignature {
       return new AutoValue_DoFnSignature_Parameter_WindowParameter(windowT);
     }
 
+    /** Returns a {@link KeyParameter}. */
+    public static KeyParameter keyT(TypeDescriptor<?> keyT) {
+      return new AutoValue_DoFnSignature_Parameter_KeyParameter(keyT);
+    }
+
     /** Returns a {@link PipelineOptionsParameter}. */
     public static PipelineOptionsParameter pipelineOptions() {
       return PIPELINE_OPTIONS_PARAMETER;
@@ -719,6 +733,13 @@ public abstract class DoFnSignature {
       TimerIdParameter() {}
     }
 
+    @AutoValue
+    public abstract static class KeyParameter extends Parameter {
+      KeyParameter() {}
+
+      public abstract TypeDescriptor<?> keyT();
+    }
+
     /**
      * Descriptor for a {@link Parameter} representing the time domain of a timer.
      *
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
index 31ad817..36fda43 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
@@ -86,6 +86,7 @@ import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.util.common.ReflectHelpers;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.Row;
 import org.apache.beam.sdk.values.TypeDescriptor;
@@ -164,7 +165,8 @@ public class DoFnSignatures {
           Parameter.TimerParameter.class,
           Parameter.StateParameter.class,
           Parameter.TimerFamilyParameter.class,
-          Parameter.TimerIdParameter.class);
+          Parameter.TimerIdParameter.class,
+          Parameter.KeyParameter.class);
 
   private static final ImmutableList<Class<? extends Parameter>>
       ALLOWED_ON_TIMER_FAMILY_PARAMETERS =
@@ -1284,6 +1286,18 @@ public class DoFnSignatures {
           rawType.equals(Instant.class),
           "@Timestamp argument must have type org.joda.time.Instant.");
       return Parameter.timestampParameter();
+    } else if (hasAnnotation(DoFn.Key.class, param.getAnnotations())) {
+      methodErrors.checkArgument(
+          KV.class.equals(inputT.getRawType()),
+          "@Key argument is expected to be use with input element of type KV.");
+
+      Type keyType = ((ParameterizedType) inputT.getType()).getActualTypeArguments()[0];
+      methodErrors.checkArgument(
+          TypeDescriptor.of(keyType).equals(paramT),
+          "@Key argument is expected to be type of %s, but found %s.",
+          keyType,
+          rawType);
+      return Parameter.keyT(paramT);
     } else if (rawType.equals(TimeDomain.class)) {
       return Parameter.timeDomainParameter();
     } else if (hasAnnotation(DoFn.SideInput.class, param.getAnnotations())) {
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
index a9e33c4..d71e0fd 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
@@ -91,6 +91,7 @@ import org.apache.beam.sdk.testing.NeedsRunner;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.testing.TestStream;
+import org.apache.beam.sdk.testing.UsesKey;
 import org.apache.beam.sdk.testing.UsesMapState;
 import org.apache.beam.sdk.testing.UsesRequiresTimeSortedInput;
 import org.apache.beam.sdk.testing.UsesSetState;
@@ -4970,4 +4971,138 @@ public class ParDoTest implements Serializable {
       pipeline.run();
     }
   }
+
+  /** Tests to validate Key in OnTimer. */
+  @RunWith(JUnit4.class)
+  public static class KeyTests extends SharedTestBase implements Serializable {
+
+    @Test
+    @Category({
+      ValidatesRunner.class,
+      UsesTimersInParDo.class,
+      UsesKey.class,
+    })
+    public void testKeyInOnTimer() throws Exception {
+      final String timerId = "foo";
+
+      DoFn<KV<String, Integer>, Integer> fn =
+          new DoFn<KV<String, Integer>, Integer>() {
+
+            @TimerId(timerId)
+            private final TimerSpec spec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+            @ProcessElement
+            public void processElement(@TimerId(timerId) Timer timer, OutputReceiver<Integer> r) {
+              timer.set(new Instant(1));
+            }
+
+            @OnTimer(timerId)
+            public void onTimer(@Key String key, OutputReceiver<Integer> r) {
+              r.output(Integer.parseInt(key));
+            }
+          };
+
+      PCollection<Integer> output =
+          pipeline.apply(Create.of(KV.of("1", 37), KV.of("2", 3))).apply(ParDo.of(fn));
+      PAssert.that(output).containsInAnyOrder(1, 2);
+      pipeline.run();
+    }
+
+    @Test
+    @Category({
+      ValidatesRunner.class,
+      UsesTimersInParDo.class,
+      UsesKey.class,
+    })
+    public void testKeyInOnTimerWithGenericKey() throws Exception {
+      final String timerId = "foo";
+
+      DoFn<KV<KV<String, String>, Integer>, Integer> fn =
+          new DoFn<KV<KV<String, String>, Integer>, Integer>() {
+
+            @TimerId(timerId)
+            private final TimerSpec spec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+            @ProcessElement
+            public void processElement(@TimerId(timerId) Timer timer, OutputReceiver<Integer> r) {
+              timer.set(new Instant(1));
+            }
+
+            @OnTimer(timerId)
+            public void onTimer(@Key KV<String, String> key, OutputReceiver<Integer> r) {
+              r.output(Integer.parseInt(key.getKey()));
+            }
+          };
+
+      PCollection<Integer> output =
+          pipeline
+              .apply(Create.of(KV.of(KV.of("1", "1"), 37), KV.of(KV.of("1", "1"), 3)))
+              .apply(ParDo.of(fn));
+      PAssert.that(output).containsInAnyOrder(1);
+      pipeline.run();
+    }
+
+    @Test
+    @Category({
+      ValidatesRunner.class,
+      UsesTimersInParDo.class,
+      UsesKey.class,
+    })
+    public void testKeyInOnTimerWithWrongKeyType() throws Exception {
+
+      thrown.expect(IllegalArgumentException.class);
+      thrown.expectMessage("@Key argument is expected to be type of");
+      thrown.expectMessage(", but found ");
+
+      final String timerId = "foo";
+
+      DoFn<KV<String, Integer>, Integer> fn =
+          new DoFn<KV<String, Integer>, Integer>() {
+
+            @TimerId(timerId)
+            private final TimerSpec spec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+            @ProcessElement
+            public void processElement(@TimerId(timerId) Timer timer, OutputReceiver<Integer> r) {
+              timer.set(new Instant(1));
+            }
+
+            @OnTimer(timerId)
+            public void onTimer(@Key Integer key, OutputReceiver<Integer> r) {}
+          };
+
+      pipeline.apply(Create.of(KV.of("1", 37), KV.of("1", 4), KV.of("2", 3))).apply(ParDo.of(fn));
+    }
+
+    @Test
+    @Category({
+      ValidatesRunner.class,
+      UsesTimersInParDo.class,
+      UsesKey.class,
+    })
+    public void testKeyInOnTimerWithoutKV() throws Exception {
+
+      thrown.expect(IllegalArgumentException.class);
+      thrown.expectMessage("@Key argument is expected to be use with input element of type KV.");
+
+      final String timerId = "foo";
+
+      DoFn<String, Integer> fn =
+          new DoFn<String, Integer>() {
+
+            @TimerId(timerId)
+            private final TimerSpec spec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+            @ProcessElement
+            public void processElement(@TimerId(timerId) Timer timer, OutputReceiver<Integer> r) {
+              timer.set(new Instant(1));
+            }
+
+            @OnTimer(timerId)
+            public void onTimer(@Key Integer key, OutputReceiver<Integer> r) {}
+          };
+
+      pipeline.apply(Create.of("1")).apply(ParDo.of(fn));
+    }
+  }
 }
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/datastore/DataStoreV1Table.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/datastore/DataStoreV1Table.java
index 3a103c8..7efdb13 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/datastore/DataStoreV1Table.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/datastore/DataStoreV1Table.java
@@ -22,7 +22,6 @@ import static com.google.datastore.v1.client.DatastoreHelper.makeValue;
 
 import com.alibaba.fastjson.JSONObject;
 import com.google.datastore.v1.Entity;
-import com.google.datastore.v1.Key;
 import com.google.datastore.v1.Query;
 import com.google.datastore.v1.Value;
 import com.google.datastore.v1.Value.ValueTypeCase;
@@ -372,14 +371,14 @@ class DataStoreV1Table extends SchemaBaseBeamTable implements Serializable {
        * @param row {@code Row} to construct a key for.
        * @return resulting {@code Key}.
        */
-      private Key constructKeyFromRow(Row row) {
+      private com.google.datastore.v1.Key constructKeyFromRow(Row row) {
         if (!useNonRandomKey) {
           // When key field is not present - use key supplier to generate a random one.
           return makeKey(kind, keySupplier.get()).build();
         }
         byte[] keyBytes = row.getBytes(keyField);
         try {
-          return Key.parseFrom(keyBytes);
+          return com.google.datastore.v1.Key.parseFrom(keyBytes);
         } catch (InvalidProtocolBufferException e) {
           throw new IllegalStateException("Failed to parse DataStore key from bytes.");
         }
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index daf1f479..98b3e66 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -271,7 +271,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
   private final DoFnInvoker<InputT, OutputT> doFnInvoker;
   private final StartBundleArgumentProvider startBundleArgumentProvider;
   private final ProcessBundleContext processContext;
-  private final OnTimerContext onTimerContext;
+  private OnTimerContext onTimerContext;
   private final FinishBundleArgumentProvider finishBundleArgumentProvider;
 
   /**
@@ -573,7 +573,6 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
         throw new IllegalStateException(
             String.format("Unknown URN %s", pTransform.getSpec().getUrn()));
     }
-    this.onTimerContext = new OnTimerContext();
     this.finishBundleArgumentProvider = new FinishBundleArgumentProvider();
 
     switch (pTransform.getSpec().getUrn()) {
@@ -966,6 +965,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
   private <K> void processTimer(String timerId, TimeDomain timeDomain, Timer<K> timer) {
     currentTimer = timer;
     currentTimeDomain = timeDomain;
+    onTimerContext = new OnTimerContext<>(timer.getUserKey());
     try {
       Iterator<BoundedWindow> windowIterator =
           (Iterator<BoundedWindow>) timer.getWindows().iterator();
@@ -1350,6 +1350,12 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
     }
 
     @Override
+    public Object key() {
+      throw new UnsupportedOperationException(
+          "Cannot access key as parameter outside of @OnTimer method.");
+    }
+
+    @Override
     public Object sideInput(String tagId) {
       return sideInput(sideInputMapping.get(tagId));
     }
@@ -1527,7 +1533,13 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
   }
 
   /** Provides arguments for a {@link DoFnInvoker} for {@link DoFn.OnTimer @OnTimer}. */
-  private class OnTimerContext extends BaseArgumentProvider<InputT, OutputT> {
+  private class OnTimerContext<K> extends BaseArgumentProvider<InputT, OutputT> {
+    private final K key;
+
+    public OnTimerContext(K key) {
+      this.key = key;
+    }
+
     private class Context extends DoFn<InputT, OutputT>.OnTimerContext {
       private Context() {
         doFn.super();
@@ -1626,6 +1638,11 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
     }
 
     @Override
+    public K key() {
+      return key;
+    }
+
+    @Override
     public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, OutputT> doFn) {
       return DoFnOutputReceivers.windowedReceiver(context, null);
     }
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/V1TestUtil.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/V1TestUtil.java
index 82cfec4..5ae3d93 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/V1TestUtil.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/V1TestUtil.java
@@ -119,7 +119,7 @@ class V1TestUtil {
     private final String kind;
     @Nullable private final String namespace;
     private final int largePropertySize;
-    private Key ancestorKey;
+    private com.google.datastore.v1.Key ancestorKey;
 
     CreateEntityFn(
         String kind, @Nullable String namespace, String ancestor, int largePropertySize) {