You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2022/07/08 19:53:34 UTC

[beam] branch master updated: [BEAM-13015, #22050] Make SDK harness msec counters faster using ordered puts (#22103)

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new fbe61507cab [BEAM-13015, #22050] Make SDK harness msec counters faster using ordered puts (#22103)
fbe61507cab is described below

commit fbe61507cab08838b4769225a76a56c1c10684ed
Author: Luke Cwik <lc...@google.com>
AuthorDate: Fri Jul 8 12:53:28 2022 -0700

    [BEAM-13015, #22050] Make SDK harness msec counters faster using ordered puts (#22103)
    
    * [BEAM-13015, #21250] Make SDK harness msec counters faster
    
    Msec counters benchmarks, 2.5x-3.5x faster, larger bundles have bigger wins
    
    Benchmark                                                               Mode  Cnt         Score        Error  Units
    ExecutionStateSamplerBenchmark.testLargeBundleHarnessStateSampler      thrpt   25     74093.492 ±    364.322  ops/s
    ExecutionStateSamplerBenchmark.testLargeBundleRunnersCoreStateSampler  thrpt   25     19838.512 ±     81.986  ops/s
    ExecutionStateSamplerBenchmark.testTinyBundleHarnessStateSampler       thrpt   25  10187284.157 ± 128030.350  ops/s
    ExecutionStateSamplerBenchmark.testTinyBundleRunnersCoreStateSampler   thrpt   25   4099369.479 ±  25219.934  ops/s
    Before
    
    Benchmark                               Mode   Cnt      Score     Error  Units
    ProcessBundleBenchmark.testTinyBundle   thrpt   25  28624.904 ± 231.755  ops/s
    ProcessBundleBenchmark.testLargeBundle  thrpt   25   1128.407 ±  22.196  ops/s
    After
    
    Benchmark                               Mode   Cnt      Score     Error  Units
    ProcessBundleBenchmark.testTinyBundle   thrpt   25  29091.576 ± 103.945  ops/s
    ProcessBundleBenchmark.testLargeBundle  thrpt   25   1158.686 ±   6.835  ops/s
---
 .../core/metrics/MonitoringInfoConstants.java      |  31 +-
 .../fnexecution/control/RemoteExecutionTest.java   |  14 +-
 .../control/ExecutionStateSamplerBenchmark.java    | 174 ++++++++
 .../ExecutionStateSamplerBenchmarkTest.java        |  58 +++
 .../java/org/apache/beam/fn/harness/FnHarness.java |  16 +-
 .../fn/harness/control/ExecutionStateSampler.java  | 418 ++++++++++++++++++
 .../fn/harness/control/ProcessBundleHandler.java   |  99 ++---
 .../harness/data/PCollectionConsumerRegistry.java  |  76 ++--
 .../harness/data/PTransformFunctionRegistry.java   |  84 ++--
 .../beam/fn/harness/status/BeamFnStatusClient.java |  28 +-
 .../harness/control/ExecutionStateSamplerTest.java | 489 +++++++++++++++++++++
 .../harness/control/ProcessBundleHandlerTest.java  |  30 +-
 .../data/PCollectionConsumerRegistryTest.java      |  59 ++-
 .../data/PTransformFunctionRegistryTest.java       |  86 +++-
 .../fn/harness/status/BeamFnStatusClientTest.java  |  11 +-
 15 files changed, 1464 insertions(+), 209 deletions(-)

diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java
index 03496fe38db..2808ae58cf1 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java
@@ -35,11 +35,13 @@ public final class MonitoringInfoConstants {
   public static final class Urns {
     public static final String ELEMENT_COUNT = extractUrn(MonitoringInfoSpecs.Enum.ELEMENT_COUNT);
     public static final String START_BUNDLE_MSECS =
-        extractUrn(MonitoringInfoSpecs.Enum.START_BUNDLE_MSECS);
+        "beam:metric:pardo_execution_time:start_bundle_msecs:v1";
+
     public static final String PROCESS_BUNDLE_MSECS =
-        extractUrn(MonitoringInfoSpecs.Enum.PROCESS_BUNDLE_MSECS);
+        "beam:metric:pardo_execution_time:process_bundle_msecs:v1";
+
     public static final String FINISH_BUNDLE_MSECS =
-        extractUrn(MonitoringInfoSpecs.Enum.FINISH_BUNDLE_MSECS);
+        "beam:metric:pardo_execution_time:finish_bundle_msecs:v1";
     public static final String TOTAL_MSECS = extractUrn(MonitoringInfoSpecs.Enum.TOTAL_MSECS);
     public static final String USER_SUM_INT64 = extractUrn(MonitoringInfoSpecs.Enum.USER_SUM_INT64);
     public static final String USER_SUM_DOUBLE =
@@ -58,6 +60,18 @@ public final class MonitoringInfoConstants {
         extractUrn(MonitoringInfoSpecs.Enum.API_REQUEST_COUNT);
     public static final String API_REQUEST_LATENCIES =
         extractUrn(MonitoringInfoSpecs.Enum.API_REQUEST_LATENCIES);
+
+    static {
+      // Validate that compile time constants match the values stored in the protos.
+      // Defining these as constants allows for usage in switch case statements and also
+      // ensures that protos don't get accidentally changed.
+      checkArgument(
+          START_BUNDLE_MSECS.equals(extractUrn(MonitoringInfoSpecs.Enum.START_BUNDLE_MSECS)));
+      checkArgument(
+          PROCESS_BUNDLE_MSECS.equals(extractUrn(MonitoringInfoSpecs.Enum.PROCESS_BUNDLE_MSECS)));
+      checkArgument(
+          FINISH_BUNDLE_MSECS.equals(extractUrn(MonitoringInfoSpecs.Enum.FINISH_BUNDLE_MSECS)));
+    }
   }
 
   /** Standardised MonitoringInfo labels that can be utilized by runners. */
@@ -91,11 +105,9 @@ public final class MonitoringInfoConstants {
     public static final String SPANNER_QUERY_NAME = "SPANNER_QUERY_NAME";
 
     static {
-      // Note: One benefit of defining these strings above, instead of pulling them in from
-      // the proto files, is to ensure that this code will crash if the strings in the proto
-      // file are changed, without modifying this file.
-      // Though, one should not change those strings either, as Runner Harnesss running old versions
-      // would not be able to understand the new label names./
+      // Validate that compile time constants match the values stored in the protos.
+      // Defining these as constants allows for usage in switch case statements and also
+      // ensures that protos don't get accidentally changed.
       checkArgument(PTRANSFORM.equals(extractLabel(MonitoringInfoLabels.TRANSFORM)));
       checkArgument(PCOLLECTION.equals(extractLabel(MonitoringInfoLabels.PCOLLECTION)));
       checkArgument(
@@ -150,6 +162,9 @@ public final class MonitoringInfoConstants {
     public static final String PROGRESS_TYPE = "beam:metrics:progress:v1";
 
     static {
+      // Validate that compile time constants match the values stored in the protos.
+      // Defining these as constants allows for usage in switch case statements and also
+      // ensures that protos don't get accidentally changed.
       checkArgument(SUM_INT64_TYPE.equals(getUrn(MonitoringInfoTypeUrns.Enum.SUM_INT64_TYPE)));
       checkArgument(SUM_DOUBLE_TYPE.equals(getUrn(MonitoringInfoTypeUrns.Enum.SUM_DOUBLE_TYPE)));
       checkArgument(
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
index 788a7ae8701..7e28a799c46 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
@@ -74,7 +74,6 @@ import org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNo
 import org.apache.beam.runners.core.construction.graph.ProtoOverrides;
 import org.apache.beam.runners.core.construction.graph.SplittableParDoExpander;
 import org.apache.beam.runners.core.metrics.DistributionData;
-import org.apache.beam.runners.core.metrics.ExecutionStateSampler;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.TypeUrns;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
@@ -858,7 +857,7 @@ public class RemoteExecutionTest implements Serializable {
     public void startBundle() throws InterruptedException {
       Metrics.counter(RemoteExecutionTest.class, START_USER_COUNTER_NAME).inc(10);
       Metrics.distribution(RemoteExecutionTest.class, START_USER_DISTRIBUTION_NAME).update(10);
-      ExecutionStateSampler.instance().doSampling(1);
+      Thread.sleep(500);
     }
 
     @ProcessElement
@@ -868,7 +867,7 @@ public class RemoteExecutionTest implements Serializable {
       ctxt.output("two");
       Metrics.counter(RemoteExecutionTest.class, PROCESS_USER_COUNTER_NAME).inc();
       Metrics.distribution(RemoteExecutionTest.class, PROCESS_USER_DISTRIBUTION_NAME).update(1);
-      ExecutionStateSampler.instance().doSampling(2);
+      Thread.sleep(500);
       AFTER_PROCESS.get(uuid).countDown();
       checkState(
           ALLOW_COMPLETION.get(uuid).await(60, TimeUnit.SECONDS),
@@ -879,14 +878,15 @@ public class RemoteExecutionTest implements Serializable {
     public void finishBundle() throws InterruptedException {
       Metrics.counter(RemoteExecutionTest.class, FINISH_USER_COUNTER_NAME).inc(100);
       Metrics.distribution(RemoteExecutionTest.class, FINISH_USER_DISTRIBUTION_NAME).update(100);
-      ExecutionStateSampler.instance().doSampling(3);
+      Thread.sleep(500);
     }
   }
 
   @Test
   @SuppressWarnings("FutureReturnValueIgnored")
   public void testMetrics() throws Exception {
-    launchSdkHarness(PipelineOptionsFactory.create());
+    launchSdkHarness(
+        PipelineOptionsFactory.fromArgs("--experiments=state_sampling_period_millis=10").create());
     MetricsDoFn metricsDoFn = new MetricsDoFn();
     Pipeline p = Pipeline.create();
 
@@ -1142,7 +1142,7 @@ public class RemoteExecutionTest implements Serializable {
             matchers.add(
                 allOf(
                     MonitoringInfoMatchers.matchSetFields(builder.build()),
-                    MonitoringInfoMatchers.counterValueGreaterThanOrEqualTo(2)));
+                    MonitoringInfoMatchers.counterValueGreaterThanOrEqualTo(1)));
 
             builder = new SimpleMonitoringInfoBuilder();
             builder.setUrn(Urns.FINISH_BUNDLE_MSECS);
@@ -1151,7 +1151,7 @@ public class RemoteExecutionTest implements Serializable {
             matchers.add(
                 allOf(
                     MonitoringInfoMatchers.matchSetFields(builder.build()),
-                    MonitoringInfoMatchers.counterValueGreaterThanOrEqualTo(3)));
+                    MonitoringInfoMatchers.counterValueGreaterThanOrEqualTo(1)));
 
             List<MonitoringInfo> oldMonitoringInfos = progressMonitoringInfos.get();
             if (oldMonitoringInfos == null) {
diff --git a/sdks/java/harness/jmh/src/main/java/org/apache/beam/fn/harness/jmh/control/ExecutionStateSamplerBenchmark.java b/sdks/java/harness/jmh/src/main/java/org/apache/beam/fn/harness/jmh/control/ExecutionStateSamplerBenchmark.java
new file mode 100644
index 00000000000..0a15302f009
--- /dev/null
+++ b/sdks/java/harness/jmh/src/main/java/org/apache/beam/fn/harness/jmh/control/ExecutionStateSamplerBenchmark.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.jmh.control;
+
+import java.io.Closeable;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.beam.runners.core.metrics.ExecutionStateSampler;
+import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
+import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Labels;
+import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
+import org.apache.beam.runners.core.metrics.SimpleExecutionState;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.Level;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Threads;
+
+/** Benchmarks for sampling execution state. */
+public class ExecutionStateSamplerBenchmark {
+  private static final String PTRANSFORM = "benchmarkPTransform";
+
+  @State(Scope.Benchmark)
+  public static class RunnersCoreStateSampler {
+    public final ExecutionStateSampler sampler = ExecutionStateSampler.newForTest();
+    public final ExecutionStateTracker tracker = new ExecutionStateTracker(sampler);
+    public final SimpleExecutionState state1 =
+        new SimpleExecutionState(
+            "process",
+            Urns.PROCESS_BUNDLE_MSECS,
+            new HashMap<>(Collections.singletonMap(Labels.PTRANSFORM, PTRANSFORM)));
+    public final SimpleExecutionState state2 =
+        new SimpleExecutionState(
+            "process",
+            Urns.PROCESS_BUNDLE_MSECS,
+            new HashMap<>(Collections.singletonMap(Labels.PTRANSFORM, PTRANSFORM)));
+    public final SimpleExecutionState state3 =
+        new SimpleExecutionState(
+            "process",
+            Urns.PROCESS_BUNDLE_MSECS,
+            new HashMap<>(Collections.singletonMap(Labels.PTRANSFORM, PTRANSFORM)));
+
+    @Setup(Level.Trial)
+    public void setup() {
+      sampler.start();
+    }
+
+    @TearDown(Level.Trial)
+    public void tearDown() {
+      sampler.stop();
+      // Print out the total millis so that the JVM doesn't optimize code away.
+      System.out.println(
+          state1.getTotalMillis()
+              + ", "
+              + state2.getTotalMillis()
+              + ", "
+              + state3.getTotalMillis());
+    }
+  }
+
+  @State(Scope.Benchmark)
+  public static class HarnessStateSampler {
+    public final org.apache.beam.fn.harness.control.ExecutionStateSampler sampler =
+        new org.apache.beam.fn.harness.control.ExecutionStateSampler(
+            PipelineOptionsFactory.create(), System::currentTimeMillis);
+    public final org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker
+        tracker = sampler.create();
+    public final org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionState state1 =
+        tracker.create("1", PTRANSFORM, PTRANSFORM + "Name", "1");
+    public final org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionState state2 =
+        tracker.create("2", PTRANSFORM, PTRANSFORM + "Name", "2");
+    public final org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionState state3 =
+        tracker.create("3", PTRANSFORM, PTRANSFORM + "Name", "3");
+
+    @TearDown(Level.Trial)
+    public void tearDown() {
+      sampler.stop();
+      Map<String, ByteString> monitoringData = new HashMap<>();
+      tracker.updateFinalMonitoringData(monitoringData);
+      // Print out the monitoring data so that the JVM doesn't optimize code away.
+      System.out.println(monitoringData);
+    }
+  }
+
+  @Benchmark
+  @Threads(1)
+  public void testTinyBundleRunnersCoreStateSampler(RunnersCoreStateSampler state)
+      throws Exception {
+    state.tracker.activate();
+    for (int i = 0; i < 3; ) {
+      Closeable close1 = state.tracker.enterState(state.state1);
+      Closeable close2 = state.tracker.enterState(state.state2);
+      Closeable close3 = state.tracker.enterState(state.state3);
+      // trivial code that is being sampled for this state
+      i += 1;
+      close3.close();
+      close2.close();
+      close1.close();
+    }
+    state.tracker.reset();
+  }
+
+  @Benchmark
+  @Threads(1)
+  public void testTinyBundleHarnessStateSampler(HarnessStateSampler state) throws Exception {
+    state.tracker.start("processBundleId");
+    for (int i = 0; i < 3; ) {
+      state.state1.activate();
+      state.state2.activate();
+      state.state3.activate();
+      // trivial code that is being sampled for this state
+      i += 1;
+      state.state3.deactivate();
+      state.state2.deactivate();
+      state.state1.deactivate();
+    }
+    state.tracker.reset();
+  }
+
+  @Benchmark
+  @Threads(1)
+  public void testLargeBundleRunnersCoreStateSampler(RunnersCoreStateSampler state)
+      throws Exception {
+    state.tracker.activate();
+    for (int i = 0; i < 1000; ) {
+      Closeable close1 = state.tracker.enterState(state.state1);
+      Closeable close2 = state.tracker.enterState(state.state2);
+      Closeable close3 = state.tracker.enterState(state.state3);
+      // trivial code that is being sampled for this state
+      i += 1;
+      close3.close();
+      close2.close();
+      close1.close();
+    }
+    state.tracker.reset();
+  }
+
+  @Benchmark
+  @Threads(1)
+  public void testLargeBundleHarnessStateSampler(HarnessStateSampler state) throws Exception {
+    state.tracker.start("processBundleId");
+    for (int i = 0; i < 1000; ) {
+      state.state1.activate();
+      state.state2.activate();
+      state.state3.activate();
+      // trivial code that is being sampled for this state
+      i += 1;
+      state.state3.deactivate();
+      state.state2.deactivate();
+      state.state1.deactivate();
+    }
+    state.tracker.reset();
+  }
+}
diff --git a/sdks/java/harness/jmh/src/test/java/org/apache/beam/fn/harness/jmh/control/ExecutionStateSamplerBenchmarkTest.java b/sdks/java/harness/jmh/src/test/java/org/apache/beam/fn/harness/jmh/control/ExecutionStateSamplerBenchmarkTest.java
new file mode 100644
index 00000000000..3e735d9b59c
--- /dev/null
+++ b/sdks/java/harness/jmh/src/test/java/org/apache/beam/fn/harness/jmh/control/ExecutionStateSamplerBenchmarkTest.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.jmh.control;
+
+import org.apache.beam.fn.harness.jmh.control.ExecutionStateSamplerBenchmark.HarnessStateSampler;
+import org.apache.beam.fn.harness.jmh.control.ExecutionStateSamplerBenchmark.RunnersCoreStateSampler;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link ExecutionStateSamplerBenchmark}. */
+@RunWith(JUnit4.class)
+public class ExecutionStateSamplerBenchmarkTest {
+  @Test
+  public void testTinyBundleRunnersCoreStateSampler() throws Exception {
+    RunnersCoreStateSampler state = new RunnersCoreStateSampler();
+    state.setup();
+    new ExecutionStateSamplerBenchmark().testTinyBundleRunnersCoreStateSampler(state);
+    state.tearDown();
+  }
+
+  @Test
+  public void testLargeBundleRunnersCoreStateSampler() throws Exception {
+    RunnersCoreStateSampler state = new RunnersCoreStateSampler();
+    state.setup();
+    new ExecutionStateSamplerBenchmark().testLargeBundleRunnersCoreStateSampler(state);
+    state.tearDown();
+  }
+
+  @Test
+  public void testTinyBundleHarnessStateSampler() throws Exception {
+    HarnessStateSampler state = new HarnessStateSampler();
+    new ExecutionStateSamplerBenchmark().testTinyBundleHarnessStateSampler(state);
+    state.tearDown();
+  }
+
+  @Test
+  public void testLargeBundleHarnessStateSampler() throws Exception {
+    HarnessStateSampler state = new HarnessStateSampler();
+    new ExecutionStateSamplerBenchmark().testLargeBundleHarnessStateSampler(state);
+    state.tearDown();
+  }
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
index 2ff0e298f80..d6206b0c697 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
@@ -24,6 +24,7 @@ import java.util.concurrent.ExecutorService;
 import java.util.function.Function;
 import javax.annotation.Nullable;
 import org.apache.beam.fn.harness.control.BeamFnControlClient;
+import org.apache.beam.fn.harness.control.ExecutionStateSampler;
 import org.apache.beam.fn.harness.control.FinalizeBundleHandler;
 import org.apache.beam.fn.harness.control.HarnessMonitoringInfosInstructionHandler;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler;
@@ -38,7 +39,6 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
 import org.apache.beam.model.fnexecution.v1.BeamFnControlGrpc;
 import org.apache.beam.model.pipeline.v1.Endpoints;
 import org.apache.beam.runners.core.construction.PipelineOptionsTranslation;
-import org.apache.beam.runners.core.metrics.ExecutionStateSampler;
 import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
 import org.apache.beam.runners.core.metrics.ShortIdMap;
 import org.apache.beam.sdk.extensions.gcp.options.GcsOptions;
@@ -218,6 +218,9 @@ public class FnHarness {
     IdGenerator idGenerator = IdGenerators.decrementingLongs();
     ShortIdMap metricsShortIds = new ShortIdMap();
     ExecutorService executorService = options.as(GcsOptions.class).getExecutorService();
+    ExecutionStateSampler executionStateSampler =
+        new ExecutionStateSampler(options, System::currentTimeMillis);
+
     // The logging client variable is not used per se, but during its lifetime (until close()) it
     // intercepts logging and sends it to the logging service.
     try (BeamFnLoggingClient logging =
@@ -275,6 +278,7 @@ public class FnHarness {
               beamFnStateGrpcClientCache,
               finalizeBundleHandler,
               metricsShortIds,
+              executionStateSampler,
               processWideCache);
 
       BeamFnStatusClient beamFnStatusClient = null;
@@ -325,14 +329,6 @@ public class FnHarness {
 
       JvmInitializers.runBeforeProcessing(options);
 
-      String samplingPeriodMills =
-          ExperimentalOptions.getExperimentValue(
-              options, ExperimentalOptions.STATE_SAMPLING_PERIOD_MILLIS);
-      if (samplingPeriodMills != null) {
-        ExecutionStateSampler.setSamplingPeriod(Integer.parseInt(samplingPeriodMills));
-      }
-      ExecutionStateSampler.instance().start();
-
       LOG.info("Entering instruction processing loop");
 
       // The control client immediately dispatches requests to an executor so we execute on the
@@ -351,7 +347,7 @@ public class FnHarness {
       processBundleHandler.shutdown();
     } finally {
       System.out.println("Shutting SDK harness down.");
-      ExecutionStateSampler.instance().stop();
+      executionStateSampler.stop();
       executorService.shutdown();
     }
   }
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java
new file mode 100644
index 00000000000..fa1bc186cc3
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.control;
+
+import com.google.auto.value.AutoValue;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
+import javax.annotation.concurrent.GuardedBy;
+import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessor;
+import org.apache.beam.runners.core.metrics.MonitoringInfoEncodings;
+import org.apache.beam.sdk.extensions.gcp.options.GcsOptions;
+import org.apache.beam.sdk.options.ExperimentalOptions;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.DateTimeUtils.MillisProvider;
+import org.joda.time.Duration;
+import org.joda.time.format.PeriodFormatter;
+import org.joda.time.format.PeriodFormatterBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Monitors the execution of one or more execution threads. */
+public class ExecutionStateSampler {
+
+  private static final Logger LOG = LoggerFactory.getLogger(ExecutionStateSampler.class);
+  private static final int DEFAULT_SAMPLING_PERIOD_MS = 200;
+  private static final long MAX_LULL_TIME_MS = TimeUnit.MINUTES.toMillis(5);
+  private static final PeriodFormatter DURATION_FORMATTER =
+      new PeriodFormatterBuilder()
+          .appendDays()
+          .appendSuffix("d")
+          .minimumPrintedDigits(2)
+          .appendHours()
+          .appendSuffix("h")
+          .printZeroAlways()
+          .appendMinutes()
+          .appendSuffix("m")
+          .appendSeconds()
+          .appendSuffix("s")
+          .toFormatter();
+  private final int periodMs;
+  private final MillisProvider clock;
+
+  @GuardedBy("activeStateTrackers")
+  private final Set<ExecutionStateTracker> activeStateTrackers;
+
+  private final Future<Void> stateSamplingThread;
+
+  @SuppressWarnings(
+      "methodref.receiver.bound.invalid" /* Synchronization ensures proper initialization */)
+  public ExecutionStateSampler(PipelineOptions options, MillisProvider clock) {
+    String samplingPeriodMills =
+        ExperimentalOptions.getExperimentValue(
+            options, ExperimentalOptions.STATE_SAMPLING_PERIOD_MILLIS);
+    this.periodMs =
+        samplingPeriodMills == null
+            ? DEFAULT_SAMPLING_PERIOD_MS
+            : Integer.parseInt(samplingPeriodMills);
+    this.clock = clock;
+    this.activeStateTrackers = new HashSet<>();
+    // We specifically synchronize to ensure that this object can complete
+    // being published before the state sampler thread starts.
+    synchronized (this) {
+      this.stateSamplingThread =
+          options.as(GcsOptions.class).getExecutorService().submit(this::stateSampler);
+    }
+  }
+
+  /** An {@link ExecutionState} represents the current state of an execution thread. */
+  public interface ExecutionState {
+
+    /**
+     * Activates this execution state within the {@link ExecutionStateTracker}.
+     *
+     * <p>Must only be invoked by the bundle processing thread.
+     */
+    void activate();
+
+    /**
+     * Restores the previously active execution state within the {@link ExecutionStateTracker}.
+     *
+     * <p>Must only be invoked by the bundle processing thread.
+     */
+    void deactivate();
+  }
+
+  /** Stops the execution of the state sampler. */
+  public void stop() {
+    stateSamplingThread.cancel(true);
+    try {
+      stateSamplingThread.get(5L * periodMs, TimeUnit.MILLISECONDS);
+    } catch (CancellationException e) {
+      // This was expected -- we were cancelling the thread.
+    } catch (InterruptedException | TimeoutException e) {
+      throw new RuntimeException(
+          "Failed to stop state sampling after waiting 5 sampling periods.", e);
+    } catch (ExecutionException e) {
+      throw new RuntimeException("Exception in state sampler", e);
+    }
+  }
+
+  /** Entry point for the state sampling thread. */
+  private Void stateSampler() throws Exception {
+    // Ensure the object finishes being published safely.
+    synchronized (this) {
+      if (stateSamplingThread == null) {
+        throw new IllegalStateException("Underinitialized ExecutionStateSampler instance");
+      }
+    }
+
+    long lastSampleTimeMillis = clock.getMillis();
+    long targetTimeMillis = lastSampleTimeMillis + periodMs;
+    while (!Thread.interrupted()) {
+      long currentTimeMillis = clock.getMillis();
+      long difference = targetTimeMillis - currentTimeMillis;
+      if (difference > 0) {
+        Thread.sleep(difference);
+      } else {
+        long millisSinceLastSample = currentTimeMillis - lastSampleTimeMillis;
+        synchronized (activeStateTrackers) {
+          for (ExecutionStateTracker activeTracker : activeStateTrackers) {
+            activeTracker.takeSample(currentTimeMillis, millisSinceLastSample);
+          }
+        }
+        lastSampleTimeMillis = currentTimeMillis;
+        targetTimeMillis = lastSampleTimeMillis + periodMs;
+      }
+    }
+    return null;
+  }
+
+  /** Returns a new {@link ExecutionStateTracker} associated with this state sampler. */
+  public ExecutionStateTracker create() {
+    return new ExecutionStateTracker();
+  }
+
+  /**
+   * Tracks the current state of a single execution thread.
+   *
+   * <p>A tracker is obtained from the enclosing sampler's {@code create()} method (the
+   * constructor is private). {@link #start} adds it to {@code activeStateTrackers} and {@link
+   * #reset} removes it again. While it is in that collection the sampling loop calls {@code
+   * takeSample} on it, which credits the elapsed time to whichever {@link ExecutionState} is
+   * current and warns about long periods without a state transition.
+   *
+   * <p>Both {@link #updateIntermediateMonitoringData} and {@link #updateFinalMonitoringData} ask
+   * every registered state to put its millisecond counter into the supplied map.
+   */
+  public class ExecutionStateTracker implements BundleProgressReporter {
+
+    // The set of execution states that this tracker is responsible for. Effectively
+    // final since create() should not be invoked once any bundle starts processing.
+    private final List<ExecutionStateImpl> executionStates;
+    // Read by multiple threads, written by the bundle processing thread lazily.
+    private final AtomicReference<@Nullable String> processBundleId;
+    // Read by multiple threads, written by the bundle processing thread lazily.
+    private final AtomicReference<@Nullable Thread> trackedThread;
+    // Read by multiple threads; written lazily by the sampler thread and by start() and reset().
+    private final AtomicLong lastTransitionTime;
+    // Read and written by the bundle processing thread frequently.
+    private long numTransitions;
+    // Read by the ExecutionStateSampler, written by the bundle processing thread lazily and
+    // frequently.
+    private final AtomicLong numTransitionsLazy;
+
+    // Read by multiple threads, read and written by the bundle processing thread lazily.
+    private final AtomicReference<@Nullable ExecutionStateImpl> currentState;
+    // Read and written by the sampler thread; also zeroed in reset() under activeStateTrackers.
+    private long transitionsAtLastSample;
+
+    private ExecutionStateTracker() {
+      this.executionStates = new ArrayList<>();
+      this.trackedThread = new AtomicReference<>();
+      this.lastTransitionTime = new AtomicLong();
+      this.numTransitionsLazy = new AtomicLong();
+      this.currentState = new AtomicReference<>();
+      this.processBundleId = new AtomicReference<>();
+    }
+
+    /**
+     * Returns an {@link ExecutionState} bound to this tracker for the specified transform and
+     * processing state.
+     *
+     * <p>The new state is appended to this tracker's list of execution states, which is the list
+     * the monitoring data update methods iterate over.
+     *
+     * @param shortId key under which the state's millisecond counter is put into monitoring data
+     * @param ptransformId id of the PTransform; included in lull warnings and in the status
+     * @param ptransformUniqueName unique name of the PTransform; included in lull warnings and in
+     *     the status
+     * @param stateName name of the processing state; included in lull warnings
+     * @return the newly registered state
+     */
+    public ExecutionState create(
+        String shortId, String ptransformId, String ptransformUniqueName, String stateName) {
+      ExecutionStateImpl newState =
+          new ExecutionStateImpl(shortId, ptransformId, ptransformUniqueName, stateName);
+      executionStates.add(newState);
+      return newState;
+    }
+
+    /**
+     * Called periodically by the {@link ExecutionStateSampler} to report time spent in this state.
+     *
+     * <p>The elapsed time is credited to the current state, if there is one. If the transition
+     * count has changed since the previous sample, the last transition time is moved to {@code
+     * currentTimeMillis}. Otherwise, once the time since the last transition exceeds {@code
+     * MAX_LULL_TIME_MS}, a warning is logged; it includes the tracked thread's stack trace when a
+     * thread is being tracked, and the PTransform details when a state is current.
+     *
+     * <p>NOTE(review): nothing visible here limits how often the warning repeats; while the lull
+     * persists it appears to be logged on every sample. Confirm whether that is intended.
+     *
+     * @param currentTimeMillis the current time.
+     * @param millisSinceLastSample the time since the last sample was reported. As an
+     *     approximation, all of that time should be associated with this state.
+     */
+    private void takeSample(long currentTimeMillis, long millisSinceLastSample) {
+      ExecutionStateImpl currentExecutionState = currentState.get();
+      if (currentExecutionState != null) {
+        currentExecutionState.takeSample(millisSinceLastSample);
+      }
+
+      long transitionsAtThisSample = numTransitionsLazy.get();
+
+      if (transitionsAtThisSample != transitionsAtLastSample) {
+        lastTransitionTime.lazySet(currentTimeMillis);
+        transitionsAtLastSample = transitionsAtThisSample;
+      } else {
+        long lullTimeMs = currentTimeMillis - lastTransitionTime.get();
+        Thread thread = trackedThread.get();
+        if (lullTimeMs > MAX_LULL_TIME_MS) {
+          if (thread == null) {
+            LOG.warn(
+                String.format(
+                    "Operation ongoing in bundle %s for at least %s without outputting or completing (stack trace unable to be generated).",
+                    processBundleId.get(),
+                    DURATION_FORMATTER.print(Duration.millis(lullTimeMs).toPeriod())));
+          } else if (currentExecutionState == null) {
+            LOG.warn(
+                String.format(
+                    "Operation ongoing in bundle %s for at least %s without outputting or completing:%n  at %s",
+                    processBundleId.get(),
+                    DURATION_FORMATTER.print(Duration.millis(lullTimeMs).toPeriod()),
+                    Joiner.on("\n  at ").join(thread.getStackTrace())));
+          } else {
+            LOG.warn(
+                String.format(
+                    "Operation ongoing in bundle %s for PTransform{id=%s, name=%s, state=%s} for at least %s without outputting or completing:%n  at %s",
+                    processBundleId.get(),
+                    currentExecutionState.ptransformId,
+                    currentExecutionState.ptransformUniqueName,
+                    currentExecutionState.stateName,
+                    DURATION_FORMATTER.print(Duration.millis(lullTimeMs).toPeriod()),
+                    Joiner.on("\n  at ").join(thread.getStackTrace())));
+          }
+        }
+      }
+    }
+
+    /**
+     * Returns status information related to this tracker or null if not tracking a bundle.
+     *
+     * <p>"Not tracking a bundle" means no tracked thread is set. Otherwise the status carries the
+     * tracked thread, the last transition time and, when a state is current, that state's
+     * PTransform id and unique name (both null when no state is current).
+     */
+    public @Nullable ExecutionStateTrackerStatus getStatus() {
+      Thread thread = trackedThread.get();
+      if (thread == null) {
+        return null;
+      }
+      long lastTransitionTimeMs = lastTransitionTime.get();
+      // We are actively processing a bundle but may have not yet entered into a state.
+      ExecutionStateImpl current = currentState.get();
+      if (current != null) {
+        return ExecutionStateTrackerStatus.create(
+            current.ptransformId, current.ptransformUniqueName, thread, lastTransitionTimeMs);
+      } else {
+        return ExecutionStateTrackerStatus.create(null, null, thread, lastTransitionTimeMs);
+      }
+    }
+
+    /**
+     * {@link ExecutionState} represents the current state of an execution thread.
+     *
+     * <p>{@link #activate} makes this state the tracker's current state and remembers the state
+     * it displaced; {@link #deactivate} restores the remembered state. Each of the two calls also
+     * increments the tracker's transition count, which the tracker's {@code takeSample} compares
+     * between samples to detect lulls.
+     *
+     * <p>{@link #reset} zeroes the accumulated and the last reported time, but only when a value
+     * has been reported before; {@code hasReportedValue} itself is left set.
+     */
+    private class ExecutionStateImpl implements ExecutionState {
+      private final String shortId;
+      private final String ptransformId;
+      private final String ptransformUniqueName;
+      private final String stateName;
+      // Accumulated by takeSample(), which the sampler invokes; zeroed by reset().
+      private long msecs;
+      // Published copy of msecs: set by takeSample() and reset(), read by updateMonitoringData().
+      private final AtomicLong lazyMsecs;
+      /** Guarded by {@link BundleProcessor#getProgressRequestLock}. */
+      private boolean hasReportedValue;
+      /** Guarded by {@link BundleProcessor#getProgressRequestLock}. */
+      private long lastReportedValue;
+      // Read and written by the bundle processing thread frequently.
+      private @Nullable ExecutionStateImpl previousState;
+
+      private ExecutionStateImpl(
+          String shortId, String ptransformId, String ptransformName, String stateName) {
+        this.shortId = shortId;
+        this.ptransformId = ptransformId;
+        this.ptransformUniqueName = ptransformName;
+        this.stateName = stateName;
+        this.lazyMsecs = new AtomicLong();
+      }
+
+      /**
+       * Called periodically by the {@link ExecutionStateTracker} to report time spent in this
+       * state.
+       *
+       * <p>Adds the elapsed time to the running total and stores the new total in {@code
+       * lazyMsecs}.
+       *
+       * @param millisSinceLastSample the time since the last sample was reported. As an
+       *     approximation, all of that time should be associated with this state.
+       */
+      public void takeSample(long millisSinceLastSample) {
+        msecs += millisSinceLastSample;
+        lazyMsecs.set(msecs);
+      }
+
+      /**
+       * Updates the monitoring data for this {@link ExecutionState}.
+       *
+       * <p>Puts the encoded millisecond total into {@code monitoringData} under this state's
+       * short id, unless a value was reported before and the total has not changed since.
+       */
+      public void updateMonitoringData(Map<String, ByteString> monitoringData) {
+        long msecsReads = lazyMsecs.get();
+        if (hasReportedValue && lastReportedValue == msecsReads) {
+          return;
+        }
+        monitoringData.put(shortId, MonitoringInfoEncodings.encodeInt64Counter(msecsReads));
+        lastReportedValue = msecsReads;
+        hasReportedValue = true;
+      }
+
+      public void reset() {
+        if (hasReportedValue) {
+          msecs = 0;
+          lazyMsecs.set(0);
+          lastReportedValue = 0;
+        }
+      }
+
+      @Override
+      public void activate() {
+        previousState = currentState.get();
+        currentState.lazySet(this);
+        numTransitions += 1;
+        numTransitionsLazy.lazySet(numTransitions);
+      }
+
+      @Override
+      public void deactivate() {
+        currentState.lazySet(previousState);
+        previousState = null;
+
+        numTransitions += 1;
+        numTransitionsLazy.lazySet(numTransitions);
+      }
+    }
+
+    /**
+     * Starts tracking execution states for specified {@code processBundleId}.
+     *
+     * <p>Records the bundle id, the current clock reading as the last transition time and the
+     * calling thread as the tracked thread, then adds this tracker to {@code
+     * activeStateTrackers} so that the sampling loop starts visiting it.
+     *
+     * <p>Only invoked by the bundle processing thread.
+     */
+    public void start(String processBundleId) {
+      this.processBundleId.lazySet(processBundleId);
+      this.lastTransitionTime.lazySet(clock.getMillis());
+      this.trackedThread.lazySet(Thread.currentThread());
+      synchronized (activeStateTrackers) {
+        activeStateTrackers.add(this);
+      }
+    }
+
+    @Override
+    public void updateIntermediateMonitoringData(Map<String, ByteString> monitoringData) {
+      for (ExecutionStateImpl executionState : executionStates) {
+        executionState.updateMonitoringData(monitoringData);
+      }
+    }
+
+    @Override
+    public void updateFinalMonitoringData(Map<String, ByteString> monitoringData) {
+      for (ExecutionStateImpl executionState : executionStates) {
+        executionState.updateMonitoringData(monitoringData);
+      }
+    }
+
+    /**
+     * Stops tracking execution states allowing for the {@link ExecutionStateTracker} to be re-used
+     * for another bundle.
+     *
+     * <p>While holding the {@code activeStateTrackers} lock, removes this tracker from that
+     * collection, resets every registered execution state and zeroes the sampler-side transition
+     * count. Afterwards clears the bundle id and tracked thread and zeroes the transition
+     * counters and the last transition time.
+     */
+    @Override
+    public void reset() {
+      synchronized (activeStateTrackers) {
+        activeStateTrackers.remove(this);
+        for (ExecutionStateImpl executionState : executionStates) {
+          executionState.reset();
+        }
+        this.transitionsAtLastSample = 0;
+      }
+      this.processBundleId.lazySet(null);
+      this.trackedThread.lazySet(null);
+      this.numTransitions = 0;
+      this.numTransitionsLazy.lazySet(0);
+      this.lastTransitionTime.lazySet(0);
+    }
+  }
+
+  /**
+   * Value class describing a tracker at one point in time: the thread being tracked, the time of
+   * its last state transition and, when known, the PTransform it is executing.
+   */
+  @AutoValue
+  public abstract static class ExecutionStateTrackerStatus {
+    /**
+     * Builds a status value from its four properties.
+     *
+     * @param ptransformId id of the PTransform, or null when none is known
+     * @param ptransformUniqueName unique name of the PTransform, or null when none is known
+     * @param trackedThread the thread being tracked
+     * @param lastTransitionTimeMs time of the last state transition, in milliseconds
+     * @return the new status value
+     */
+    public static ExecutionStateTrackerStatus create(
+        @Nullable String ptransformId,
+        @Nullable String ptransformUniqueName,
+        Thread trackedThread,
+        long lastTransitionTimeMs) {
+      ExecutionStateTrackerStatus status =
+          new AutoValue_ExecutionStateSampler_ExecutionStateTrackerStatus(
+              ptransformId, ptransformUniqueName, trackedThread, lastTransitionTimeMs);
+      return status;
+    }
+
+    // The accessors below stay in this order: it matches the argument order passed to the
+    // generated constructor above.
+
+    /** Returns the PTransform id, or null when none is known. */
+    public abstract @Nullable String getPTransformId();
+
+    /** Returns the PTransform unique name, or null when none is known. */
+    public abstract @Nullable String getPTransformUniqueName();
+
+    /** Returns the thread being tracked. */
+    public abstract Thread getTrackedThread();
+
+    /** Returns the time of the last state transition, in milliseconds. */
+    public abstract long getLastTransitionTimeMillis();
+  }
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index 8467a57c179..9db3ab2785d 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -18,7 +18,6 @@
 package org.apache.beam.fn.harness.control;
 
 import com.google.auto.value.AutoValue;
-import java.io.Closeable;
 import java.io.IOException;
 import java.time.Duration;
 import java.util.ArrayList;
@@ -48,6 +47,7 @@ import org.apache.beam.fn.harness.Caches.ClearableCache;
 import org.apache.beam.fn.harness.PTransformRunnerFactory;
 import org.apache.beam.fn.harness.PTransformRunnerFactory.Context;
 import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar;
+import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker;
 import org.apache.beam.fn.harness.control.FinalizeBundleHandler.CallbackRegistration;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.data.PCollectionConsumerRegistry;
@@ -74,9 +74,8 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.WindowingStrategy;
 import org.apache.beam.runners.core.construction.BeamUrns;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.Timer;
-import org.apache.beam.runners.core.metrics.ExecutionStateSampler;
-import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
 import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
+import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
 import org.apache.beam.runners.core.metrics.ShortIdMap;
 import org.apache.beam.sdk.fn.data.BeamFnDataInboundObserver2;
 import org.apache.beam.sdk.fn.data.BeamFnDataOutboundAggregator;
@@ -156,6 +155,7 @@ public class ProcessBundleHandler {
   private final FinalizeBundleHandler finalizeBundleHandler;
   private final ShortIdMap shortIds;
   private final boolean runnerAcceptsShortIds;
+  private final ExecutionStateSampler executionStateSampler;
   private final Map<String, PTransformRunnerFactory> urnToPTransformRunnerFactoryMap;
   private final PTransformRunnerFactory defaultPTransformRunnerFactory;
   private final Cache<Object, Object> processWideCache;
@@ -170,6 +170,7 @@ public class ProcessBundleHandler {
       BeamFnStateGrpcClientCache beamFnStateGrpcClientCache,
       FinalizeBundleHandler finalizeBundleHandler,
       ShortIdMap shortIds,
+      ExecutionStateSampler executionStateSampler,
       Cache<Object, Object> processWideCache) {
     this(
         options,
@@ -179,6 +180,7 @@ public class ProcessBundleHandler {
         beamFnStateGrpcClientCache,
         finalizeBundleHandler,
         shortIds,
+        executionStateSampler,
         REGISTERED_RUNNER_FACTORIES,
         processWideCache,
         new BundleProcessorCache());
@@ -193,6 +195,7 @@ public class ProcessBundleHandler {
       BeamFnStateGrpcClientCache beamFnStateGrpcClientCache,
       FinalizeBundleHandler finalizeBundleHandler,
       ShortIdMap shortIds,
+      ExecutionStateSampler executionStateSampler,
       Map<String, PTransformRunnerFactory> urnToPTransformRunnerFactoryMap,
       Cache<Object, Object> processWideCache,
       BundleProcessorCache bundleProcessorCache) {
@@ -206,6 +209,7 @@ public class ProcessBundleHandler {
     this.runnerAcceptsShortIds =
         runnerCapabilities.contains(
             BeamUrns.getUrn(RunnerApi.StandardRunnerProtocols.Enum.MONITORING_INFO_SHORT_IDS));
+    this.executionStateSampler = executionStateSampler;
     this.urnToPTransformRunnerFactoryMap = urnToPTransformRunnerFactoryMap;
     this.defaultPTransformRunnerFactory =
         new UnknownPTransformRunnerFactory(urnToPTransformRunnerFactoryMap.keySet());
@@ -366,7 +370,8 @@ public class ProcessBundleHandler {
                     @Override
                     public <T> void addPCollectionConsumer(
                         String pCollectionId, FnDataReceiver<WindowedValue<T>> consumer) {
-                      pCollectionConsumerRegistry.register(pCollectionId, pTransformId, consumer);
+                      pCollectionConsumerRegistry.register(
+                          pCollectionId, pTransformId, pTransform.getUniqueName(), consumer);
                     }
 
                     @Override
@@ -420,12 +425,14 @@ public class ProcessBundleHandler {
 
                     @Override
                     public void addStartBundleFunction(ThrowingRunnable startFunction) {
-                      startFunctionRegistry.register(pTransformId, startFunction);
+                      startFunctionRegistry.register(
+                          pTransformId, pTransform.getUniqueName(), startFunction);
                     }
 
                     @Override
                     public void addFinishBundleFunction(ThrowingRunnable finishFunction) {
-                      finishFunctionRegistry.register(pTransformId, finishFunction);
+                      finishFunctionRegistry.register(
+                          pTransformId, pTransform.getUniqueName(), finishFunction);
                     }
 
                     @Override
@@ -506,7 +513,8 @@ public class ProcessBundleHandler {
       ExecutionStateTracker stateTracker = bundleProcessor.getStateTracker();
 
       try (HandleStateCallsForBundle beamFnStateClient = bundleProcessor.getBeamFnStateClient()) {
-        try (Closeable closeTracker = stateTracker.activate()) {
+        stateTracker.start(request.getInstructionId());
+        try {
           // Already in reverse topological order so we don't need to do anything.
           for (ThrowingRunnable startFunction : startFunctionRegistry.getFunctions()) {
             LOG.debug("Starting function {}", startFunction);
@@ -542,30 +550,35 @@ public class ProcessBundleHandler {
             LOG.debug("Finishing function {}", finishFunction);
             finishFunction.run();
           }
-        }
-
-        // If bundleProcessor has not flushed any elements, embed them in response.
-        embedOutboundElementsIfApplicable(response, bundleProcessor);
-
-        // Add all checkpointed residuals to the response.
-        response.addAllResidualRoots(bundleProcessor.getSplitListener().getResidualRoots());
 
-        // Add all metrics to the response.
-        Map<String, ByteString> monitoringData = finalMonitoringData(bundleProcessor);
-        if (runnerAcceptsShortIds) {
-          response.putAllMonitoringData(monitoringData);
-        } else {
-          for (Map.Entry<String, ByteString> metric : monitoringData.entrySet()) {
-            response.addMonitoringInfos(
-                shortIds.get(metric.getKey()).toBuilder().setPayload(metric.getValue()));
+          // If bundleProcessor has not flushed any elements, embed them in response.
+          embedOutboundElementsIfApplicable(response, bundleProcessor);
+
+          // Add all checkpointed residuals to the response.
+          response.addAllResidualRoots(bundleProcessor.getSplitListener().getResidualRoots());
+
+          // Add all metrics to the response.
+          bundleProcessor.getProgressRequestLock().lock();
+          Map<String, ByteString> monitoringData = finalMonitoringData(bundleProcessor);
+          if (runnerAcceptsShortIds) {
+            response.putAllMonitoringData(monitoringData);
+          } else {
+            for (Map.Entry<String, ByteString> metric : monitoringData.entrySet()) {
+              response.addMonitoringInfos(
+                  shortIds.get(metric.getKey()).toBuilder().setPayload(metric.getValue()));
+            }
           }
-        }
 
-        if (!bundleProcessor.getBundleFinalizationCallbackRegistrations().isEmpty()) {
-          finalizeBundleHandler.registerCallbacks(
-              bundleProcessor.getInstructionId(),
-              ImmutableList.copyOf(bundleProcessor.getBundleFinalizationCallbackRegistrations()));
-          response.setRequiresFinalization(true);
+          if (!bundleProcessor.getBundleFinalizationCallbackRegistrations().isEmpty()) {
+            finalizeBundleHandler.registerCallbacks(
+                bundleProcessor.getInstructionId(),
+                ImmutableList.copyOf(bundleProcessor.getBundleFinalizationCallbackRegistrations()));
+            response.setRequiresFinalization(true);
+          }
+        } finally {
+          // We specifically deactivate state tracking while we are holding the progress request and
+          // sampling locks.
+          stateTracker.reset();
         }
       }
 
@@ -668,15 +681,6 @@ public class ProcessBundleHandler {
   private Map<String, ByteString> intermediateMonitoringData(BundleProcessor bundleProcessor)
       throws Exception {
     Map<String, ByteString> monitoringData = new HashMap<>();
-    // Get start bundle Execution Time Metrics.
-    monitoringData.putAll(
-        bundleProcessor.getStartFunctionRegistry().getExecutionTimeMonitoringData(shortIds));
-    // Get process bundle Execution Time Metrics.
-    monitoringData.putAll(
-        bundleProcessor.getpCollectionConsumerRegistry().getExecutionTimeMonitoringData(shortIds));
-    // Get finish bundle Execution Time Metrics.
-    monitoringData.putAll(
-        bundleProcessor.getFinishFunctionRegistry().getExecutionTimeMonitoringData(shortIds));
     // Extract MonitoringInfos that come from the metrics container registry.
     monitoringData.putAll(
         bundleProcessor.getMetricsContainerRegistry().getMonitoringData(shortIds));
@@ -689,17 +693,7 @@ public class ProcessBundleHandler {
 
   private Map<String, ByteString> finalMonitoringData(BundleProcessor bundleProcessor)
       throws Exception {
-    bundleProcessor.getProgressRequestLock().lock();
     HashMap<String, ByteString> monitoringData = new HashMap<>();
-    // Get start bundle Execution Time Metrics.
-    monitoringData.putAll(
-        bundleProcessor.getStartFunctionRegistry().getExecutionTimeMonitoringData(shortIds));
-    // Get process bundle Execution Time Metrics.
-    monitoringData.putAll(
-        bundleProcessor.getpCollectionConsumerRegistry().getExecutionTimeMonitoringData(shortIds));
-    // Get finish bundle Execution Time Metrics.
-    monitoringData.putAll(
-        bundleProcessor.getFinishFunctionRegistry().getExecutionTimeMonitoringData(shortIds));
     // Extract MonitoringInfos that come from the metrics container registry.
     monitoringData.putAll(
         bundleProcessor.getMetricsContainerRegistry().getMonitoringData(shortIds));
@@ -744,8 +738,8 @@ public class ProcessBundleHandler {
     BundleProgressReporter.InMemory bundleProgressReporterAndRegistrar =
         new BundleProgressReporter.InMemory();
     MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    ExecutionStateTracker stateTracker =
-        new ExecutionStateTracker(ExecutionStateSampler.instance());
+    ExecutionStateTracker stateTracker = executionStateSampler.create();
+    bundleProgressReporterAndRegistrar.register(stateTracker);
     PCollectionConsumerRegistry pCollectionConsumerRegistry =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
@@ -757,10 +751,10 @@ public class ProcessBundleHandler {
 
     PTransformFunctionRegistry startFunctionRegistry =
         new PTransformFunctionRegistry(
-            metricsContainerRegistry, stateTracker, ExecutionStateTracker.START_STATE_NAME);
+            metricsContainerRegistry, shortIds, stateTracker, Urns.START_BUNDLE_MSECS);
     PTransformFunctionRegistry finishFunctionRegistry =
         new PTransformFunctionRegistry(
-            metricsContainerRegistry, stateTracker, ExecutionStateTracker.FINISH_STATE_NAME);
+            metricsContainerRegistry, shortIds, stateTracker, Urns.FINISH_BUNDLE_MSECS);
     List<ThrowingRunnable> resetFunctions = new ArrayList<>();
     List<ThrowingRunnable> tearDownFunctions = new ArrayList<>();
 
@@ -1118,10 +1112,7 @@ public class ProcessBundleHandler {
           this.bundleCache = null;
         }
       }
-      getStartFunctionRegistry().reset();
-      getFinishFunctionRegistry().reset();
       getSplitListener().clear();
-      getpCollectionConsumerRegistry().reset();
       getMetricsContainerRegistry().reset();
       getStateTracker().reset();
       getBundleFinalizationCallbackRegistrations().clear();
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
index b222e829185..34d0967b85b 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
@@ -27,13 +27,15 @@ import java.util.Map;
 import java.util.Random;
 import org.apache.beam.fn.harness.HandlesSplits;
 import org.apache.beam.fn.harness.control.BundleProgressReporter;
+import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionState;
+import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker;
 import org.apache.beam.fn.harness.control.Metrics;
 import org.apache.beam.fn.harness.control.Metrics.BundleCounter;
 import org.apache.beam.fn.harness.control.Metrics.BundleDistribution;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
+import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.runners.core.construction.RehydratedComponents;
-import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
 import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Labels;
@@ -41,9 +43,7 @@ import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.TypeUrns;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
 import org.apache.beam.runners.core.metrics.MonitoringInfoMetricName;
 import org.apache.beam.runners.core.metrics.ShortIdMap;
-import org.apache.beam.runners.core.metrics.SimpleExecutionState;
 import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
-import org.apache.beam.runners.core.metrics.SimpleStateRegistry;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.metrics.Distribution;
@@ -52,7 +52,6 @@ import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
 import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
-import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 
 /**
@@ -75,7 +74,7 @@ public class PCollectionConsumerRegistry {
     public static ConsumerAndMetadata forConsumer(
         FnDataReceiver consumer,
         String pTransformId,
-        SimpleExecutionState state,
+        ExecutionState state,
         MetricsContainer metricsContainer) {
       return new AutoValue_PCollectionConsumerRegistry_ConsumerAndMetadata(
           consumer, pTransformId, state, metricsContainer);
@@ -85,7 +84,7 @@ public class PCollectionConsumerRegistry {
 
     public abstract String getPTransformId();
 
-    public abstract SimpleExecutionState getExecutionState();
+    public abstract ExecutionState getExecutionState();
 
     public abstract MetricsContainer getMetricsContainer();
   }
@@ -95,7 +94,6 @@ public class PCollectionConsumerRegistry {
   private final ShortIdMap shortIdMap;
   private final Map<String, List<ConsumerAndMetadata>> pCollectionIdsToConsumers;
   private final Map<String, FnDataReceiver> pCollectionIdsToWrappedConsumer;
-  private final SimpleStateRegistry executionStates;
   private final BundleProgressReporter.Registrar bundleProgressReporterRegistrar;
   private final ProcessBundleDescriptor processBundleDescriptor;
   private final RehydratedComponents rehydratedComponents;
@@ -111,7 +109,6 @@ public class PCollectionConsumerRegistry {
     this.shortIdMap = shortIdMap;
     this.pCollectionIdsToConsumers = new HashMap<>();
     this.pCollectionIdsToWrappedConsumer = new HashMap<>();
-    this.executionStates = new SimpleStateRegistry();
     this.bundleProgressReporterRegistrar = bundleProgressReporterRegistrar;
     this.processBundleDescriptor = processBundleDescriptor;
     this.rehydratedComponents =
@@ -133,13 +130,17 @@ public class PCollectionConsumerRegistry {
    *
    * @param pCollectionId
    * @param pTransformId
+   * @param pTransformUniqueName
    * @param consumer
    * @param <T> the element type of the PCollection
    * @throws RuntimeException if {@code register()} is called after {@code
    *     getMultiplexingConsumer()} is called.
    */
   public <T> void register(
-      String pCollectionId, String pTransformId, FnDataReceiver<WindowedValue<T>> consumer) {
+      String pCollectionId,
+      String pTransformId,
+      String pTransformUniqueName,
+      FnDataReceiver<WindowedValue<T>> consumer) {
     // Just save these consumers for now, but package them up later with an
     // ElementCountFnDataReceiver and possibly a MultiplexingFnDataReceiver
     // if there are multiple consumers.
@@ -149,20 +150,35 @@ public class PCollectionConsumerRegistry {
               + "calling getMultiplexingConsumer.");
     }
 
-    HashMap<String, String> labelsMetadata = new HashMap<>();
-    labelsMetadata.put(MonitoringInfoConstants.Labels.PTRANSFORM, pTransformId);
-    SimpleExecutionState state =
-        new SimpleExecutionState(
-            ExecutionStateTracker.PROCESS_STATE_NAME,
-            MonitoringInfoConstants.Urns.PROCESS_BUNDLE_MSECS,
-            labelsMetadata);
-    executionStates.register(state);
+    SimpleMonitoringInfoBuilder builder = new SimpleMonitoringInfoBuilder();
+    builder.setUrn(MonitoringInfoConstants.Urns.PROCESS_BUNDLE_MSECS);
+    builder.setType(MonitoringInfoConstants.TypeUrns.SUM_INT64_TYPE);
+    builder.setLabel(MonitoringInfoConstants.Labels.PTRANSFORM, pTransformId);
+    MonitoringInfo mi = builder.build();
+    if (mi == null) {
+      throw new IllegalStateException(
+          String.format(
+              "Unable to construct %s counter for PTransform {id=%s, name=%s}",
+              MonitoringInfoConstants.Urns.PROCESS_BUNDLE_MSECS,
+              pTransformId,
+              pTransformUniqueName));
+    }
+    String shortId = shortIdMap.getOrCreateShortId(mi);
+    ExecutionState executionState =
+        stateTracker.create(
+            shortId,
+            pTransformId,
+            pTransformUniqueName,
+            org.apache.beam.runners.core.metrics.ExecutionStateTracker.PROCESS_STATE_NAME);
 
     List<ConsumerAndMetadata> consumerAndMetadatas =
         pCollectionIdsToConsumers.computeIfAbsent(pCollectionId, (unused) -> new ArrayList<>());
     consumerAndMetadatas.add(
         ConsumerAndMetadata.forConsumer(
-            consumer, pTransformId, state, metricsContainerRegistry.getContainer(pTransformId)));
+            consumer,
+            pTransformId,
+            executionState,
+            metricsContainerRegistry.getContainer(pTransformId)));
   }
 
   /**
@@ -215,16 +231,6 @@ public class PCollectionConsumerRegistry {
         });
   }
 
-  /** @return Execution Time Monitoring data based on the tracked start or finish function. */
-  public Map<String, ByteString> getExecutionTimeMonitoringData(ShortIdMap shortIds) {
-    return executionStates.getExecutionTimeMonitoringData(shortIds);
-  }
-
-  /** Reset the execution states of the registered functions. */
-  public void reset() {
-    executionStates.reset();
-  }
-
   /**
    * A wrapping {@code FnDataReceiver<WindowedValue<T>>} which counts the number of elements
    * consumed by the original {@code FnDataReceiver<WindowedValue<T>> consumer} and sets up metrics
@@ -234,7 +240,7 @@ public class PCollectionConsumerRegistry {
    */
   private class MetricTrackingFnDataReceiver<T> implements FnDataReceiver<WindowedValue<T>> {
     private final FnDataReceiver<WindowedValue<T>> delegate;
-    private final SimpleExecutionState state;
+    private final ExecutionState state;
     private final BundleCounter elementCountCounter;
     private final SampleByteSizeDistribution<T> sampledByteSizeDistribution;
     private final Coder<T> coder;
@@ -293,8 +299,11 @@ public class PCollectionConsumerRegistry {
       // created. Also use the ExecutionStateTracker and enter an appropriate state to track the
       // Process Bundle Execution time metric.
       try (Closeable closeable = MetricsEnvironment.scopedMetricsContainer(metricsContainer)) {
-        try (Closeable trackerCloseable = stateTracker.enterState(state)) {
+        state.activate();
+        try {
           this.delegate.accept(input);
+        } finally {
+          state.deactivate();
         }
       }
       this.sampledByteSizeDistribution.finishLazyUpdate();
@@ -369,9 +378,12 @@ public class PCollectionConsumerRegistry {
 
         try (Closeable closeable =
             MetricsEnvironment.scopedMetricsContainer(consumerAndMetadata.getMetricsContainer())) {
-          try (Closeable trackerCloseable =
-              stateTracker.enterState(consumerAndMetadata.getExecutionState())) {
+          ExecutionState state = consumerAndMetadata.getExecutionState();
+          state.activate();
+          try {
             consumerAndMetadata.getConsumer().accept(input);
+          } finally {
+            state.deactivate();
           }
         }
         this.sampledByteSizeDistribution.finishLazyUpdate();
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java
index 47b2254c2e7..2b13e02c610 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java
@@ -19,20 +19,18 @@ package org.apache.beam.fn.harness.data;
 
 import java.io.Closeable;
 import java.util.ArrayList;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
+import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionState;
+import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker;
 import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
-import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
 import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
 import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
+import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
 import org.apache.beam.runners.core.metrics.ShortIdMap;
-import org.apache.beam.runners.core.metrics.SimpleExecutionState;
-import org.apache.beam.runners.core.metrics.SimpleStateRegistry;
+import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
 import org.apache.beam.sdk.function.ThrowingRunnable;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
-import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
 
 /**
  * A class to to register and retrieve functions for bundle processing (i.e. the start, or finish
@@ -61,10 +59,11 @@ import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
 public class PTransformFunctionRegistry {
 
   private MetricsContainerStepMap metricsContainerRegistry;
-  private ExecutionStateTracker stateTracker;
-  private String executionStateName;
-  private List<ThrowingRunnable> runnables = new ArrayList<>();
-  private SimpleStateRegistry executionStates = new SimpleStateRegistry();
+  private final ExecutionStateTracker stateTracker;
+  private final String executionStateUrn;
+  private final ShortIdMap shortIds;
+  private final List<ThrowingRunnable> runnables = new ArrayList<>();
+  private final String stateName;
 
   /**
    * Construct the registry to run for either start or finish bundle functions.
@@ -72,14 +71,26 @@ public class PTransformFunctionRegistry {
    * @param metricsContainerRegistry - Used to enable a metric container to properly account for the
    *     pTransform in user metrics.
    * @param stateTracker - The tracker to enter states in order to calculate execution time metrics.
-   * @param executionStateName - The state name for the state .
+   * @param executionStateUrn - The URN for the execution state (start or finish bundle msecs).
    */
   public PTransformFunctionRegistry(
       MetricsContainerStepMap metricsContainerRegistry,
+      ShortIdMap shortIds,
       ExecutionStateTracker stateTracker,
-      String executionStateName) {
+      String executionStateUrn) {
+    switch (executionStateUrn) {
+      case Urns.START_BUNDLE_MSECS:
+        stateName = org.apache.beam.runners.core.metrics.ExecutionStateTracker.START_STATE_NAME;
+        break;
+      case Urns.FINISH_BUNDLE_MSECS:
+        stateName = org.apache.beam.runners.core.metrics.ExecutionStateTracker.FINISH_STATE_NAME;
+        break;
+      default:
+        throw new IllegalArgumentException(String.format("Unknown URN %s", executionStateUrn));
+    }
     this.metricsContainerRegistry = metricsContainerRegistry;
-    this.executionStateName = executionStateName;
+    this.shortIds = shortIds;
+    this.executionStateUrn = executionStateUrn;
     this.stateTracker = stateTracker;
   }
 
@@ -87,49 +98,42 @@ public class PTransformFunctionRegistry {
    * Register the runnable to process the specific pTransformId and track its execution time.
    *
    * @param pTransformId
+   * @param pTransformUniqueName
    * @param runnable
    */
-  public void register(String pTransformId, ThrowingRunnable runnable) {
-    HashMap<String, String> labelsMetadata = new HashMap<String, String>();
-    labelsMetadata.put(MonitoringInfoConstants.Labels.PTRANSFORM, pTransformId);
-    String executionTimeUrn = "";
-    if (executionStateName.equals(ExecutionStateTracker.START_STATE_NAME)) {
-      executionTimeUrn = MonitoringInfoConstants.Urns.START_BUNDLE_MSECS;
-    } else if (executionStateName.equals(ExecutionStateTracker.FINISH_STATE_NAME)) {
-      executionTimeUrn = MonitoringInfoConstants.Urns.FINISH_BUNDLE_MSECS;
+  public void register(
+      String pTransformId, String pTransformUniqueName, ThrowingRunnable runnable) {
+    SimpleMonitoringInfoBuilder miBuilder = new SimpleMonitoringInfoBuilder();
+    miBuilder.setUrn(executionStateUrn);
+    miBuilder.setType(MonitoringInfoConstants.TypeUrns.SUM_INT64_TYPE);
+    miBuilder.setLabel(MonitoringInfoConstants.Labels.PTRANSFORM, pTransformId);
+    MonitoringInfo mi = miBuilder.build();
+    if (mi == null) {
+      throw new IllegalStateException(
+          String.format(
+              "Unable to construct %s counter for PTransform {id=%s, name=%s}",
+              executionStateUrn, pTransformId, pTransformUniqueName));
     }
+    String shortId = shortIds.getOrCreateShortId(mi);
+    ExecutionState executionState =
+        stateTracker.create(shortId, pTransformId, pTransformUniqueName, stateName);
 
-    SimpleExecutionState state =
-        new SimpleExecutionState(this.executionStateName, executionTimeUrn, labelsMetadata);
-    executionStates.register(state);
     MetricsContainerImpl container = metricsContainerRegistry.getContainer(pTransformId);
 
     ThrowingRunnable wrapped =
         () -> {
           try (Closeable metricCloseable = MetricsEnvironment.scopedMetricsContainer(container)) {
-            try (Closeable trackerCloseable = this.stateTracker.enterState(state)) {
+            executionState.activate();
+            try {
               runnable.run();
+            } finally {
+              executionState.deactivate();
             }
           }
         };
     runnables.add(wrapped);
   }
 
-  /** Reset the execution states of the registered functions. */
-  public void reset() {
-    executionStates.reset();
-  }
-
-  /** @return Execution Time MonitoringInfos based on the tracked start or finish function. */
-  public List<MonitoringInfo> getExecutionTimeMonitoringInfos() {
-    return executionStates.getExecutionTimeMonitoringInfos();
-  }
-
-  /** @return Execution Time MonitoringInfos based on the tracked start or finish function. */
-  public Map<String, ByteString> getExecutionTimeMonitoringData(ShortIdMap shortIds) {
-    return executionStates.getExecutionTimeMonitoringData(shortIds);
-  }
-
   /**
    * @return A list of wrapper functions which will invoke the registered functions indirectly. The
    *     order of registry is maintained.
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java
index b0f7b182f29..216430956f5 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java
@@ -29,6 +29,7 @@ import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
 import org.apache.beam.fn.harness.Cache;
+import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTrackerStatus;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessor;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessorCache;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
@@ -36,12 +37,12 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.WorkerStatusRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.WorkerStatusResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnWorkerStatusGrpc;
 import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor;
-import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.ManagedChannel;
 import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.stub.StreamObserver;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
 import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.DateTimeUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -203,20 +204,19 @@ public class BeamFnStatusClient implements AutoCloseable {
       activeBundlesState.add("No active processing bundles.");
     } else {
       List<BundleState> bundleStates = new ArrayList<>();
-      processBundleCache.getActiveBundleProcessors().keySet().stream()
+      processBundleCache.getActiveBundleProcessors().entrySet().stream()
           .forEach(
-              instruction -> {
-                BundleProcessor bundleProcessor = processBundleCache.find(instruction);
-                if (bundleProcessor != null) {
-                  ExecutionStateTracker executionStateTracker = bundleProcessor.getStateTracker();
-                  Thread trackedTread = executionStateTracker.getTrackedThread();
-                  if (trackedTread != null) {
-                    bundleStates.add(
-                        new BundleState(
-                            instruction,
-                            trackedTread.getName(),
-                            executionStateTracker.getMillisSinceLastTransition()));
-                  }
+              instructionAndBundleProcessor -> {
+                BundleProcessor bundleProcessor = instructionAndBundleProcessor.getValue();
+                ExecutionStateTrackerStatus executionStateTrackerStatus =
+                    bundleProcessor.getStateTracker().getStatus();
+                if (executionStateTrackerStatus != null) {
+                  bundleStates.add(
+                      new BundleState(
+                          instructionAndBundleProcessor.getKey(),
+                          executionStateTrackerStatus.getTrackedThread().getName(),
+                          DateTimeUtils.currentTimeMillis()
+                              - executionStateTrackerStatus.getLastTransitionTimeMillis()));
                 }
               });
       bundleStates.stream()
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java
new file mode 100644
index 00000000000..558eb004459
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java
@@ -0,0 +1,489 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.control;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionState;
+import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker;
+import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTrackerStatus;
+import org.apache.beam.runners.core.metrics.MonitoringInfoEncodings;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.testing.ExpectedLogs;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
+import org.joda.time.DateTimeUtils.MillisProvider;
+import org.joda.time.Duration;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+/** Tests for {@link ExecutionStateSampler}. */
+@RunWith(JUnit4.class)
+public class ExecutionStateSamplerTest {
+
+  @Rule public ExpectedLogs expectedLogs = ExpectedLogs.none(ExecutionStateSampler.class);
+
+  @Test
+  public void testSamplingProducesCorrectFinalResults() throws Exception {
+    MillisProvider clock = mock(MillisProvider.class);
+    ExecutionStateSampler sampler =
+        new ExecutionStateSampler(
+            PipelineOptionsFactory.fromArgs("--experiments=state_sampling_period_millis=10")
+                .create(),
+            clock);
+    ExecutionStateTracker tracker1 = sampler.create();
+    ExecutionState state1 =
+        tracker1.create("shortId1", "ptransformId1", "ptransformIdName1", "process");
+
+    ExecutionStateTracker tracker2 = sampler.create();
+    ExecutionState state2 =
+        tracker2.create("shortId2", "ptransformId2", "ptransformIdName2", "process");
+
+    CountDownLatch waitTillActive = new CountDownLatch(1);
+    CountDownLatch waitTillIntermediateReport = new CountDownLatch(1);
+    CountDownLatch waitTillStatesDeactivated = new CountDownLatch(1);
+    CountDownLatch waitForSamples = new CountDownLatch(1);
+    CountDownLatch waitForMoreSamples = new CountDownLatch(1);
+    CountDownLatch waitForEvenMoreSamples = new CountDownLatch(1);
+    Thread testThread = Thread.currentThread();
+    Mockito.when(clock.getMillis())
+        .thenAnswer(
+            new Answer<Long>() {
+              private long currentTime;
+
+              @Override
+              public Long answer(InvocationOnMock invocation) throws Throwable {
+                if (Thread.currentThread().equals(testThread)) {
+                  return 1L;
+                } else {
+                  // Block the state sampling thread till the state is active
+                  // and unblock the state transition once a certain number of samples
+                  // have been taken.
+                  // The clock advances in 100ms steps: to 1000 once the states are active,
+                  // to 1500 once the intermediate report was taken, and finally to 1600
+                  // once the states have been deactivated.
+                  if (currentTime < 1000L) {
+                    waitTillActive.await();
+                    currentTime += 100L;
+                  } else if (currentTime < 1500L) {
+                    waitForSamples.countDown();
+                    waitTillIntermediateReport.await();
+                    currentTime += 100L;
+                  } else if (currentTime == 1500L) {
+                    waitForMoreSamples.countDown();
+                    waitTillStatesDeactivated.await();
+                    currentTime = 1600L;
+                  } else if (currentTime == 1600L) {
+                    waitForEvenMoreSamples.countDown();
+                  }
+                  return currentTime;
+                }
+              }
+            });
+
+    // No tracked thread
+    assertNull(tracker1.getStatus());
+    assertNull(tracker2.getStatus());
+
+    tracker1.start("bundleId1");
+    tracker2.start("bundleId2");
+
+    state1.activate();
+    state2.activate();
+
+    // Check that the status returns a value as soon as it is activated.
+    ExecutionStateTrackerStatus activeBundleStatus1 = tracker1.getStatus();
+    ExecutionStateTrackerStatus activeBundleStatus2 = tracker2.getStatus();
+    assertEquals("ptransformId1", activeBundleStatus1.getPTransformId());
+    assertEquals("ptransformId2", activeBundleStatus2.getPTransformId());
+    assertEquals("ptransformIdName1", activeBundleStatus1.getPTransformUniqueName());
+    assertEquals("ptransformIdName2", activeBundleStatus2.getPTransformUniqueName());
+    assertEquals(Thread.currentThread(), activeBundleStatus1.getTrackedThread());
+    assertEquals(Thread.currentThread(), activeBundleStatus2.getTrackedThread());
+    assertThat(
+        activeBundleStatus1.getLastTransitionTimeMillis(),
+        // Because we are using lazySet, we aren't guaranteed to see the latest value
+        // but we should definitely be seeing a value that isn't zero
+        equalTo(1L));
+    assertThat(
+        activeBundleStatus2.getLastTransitionTimeMillis(),
+        // The mocked clock returns 1L for every call made on the test thread.
+        equalTo(1L));
+
+    waitTillActive.countDown();
+    waitForSamples.await();
+
+    // Check that we get additional data about the active PTransform.
+    ExecutionStateTrackerStatus activeStateStatus1 = tracker1.getStatus();
+    ExecutionStateTrackerStatus activeStateStatus2 = tracker2.getStatus();
+    assertEquals("ptransformId1", activeStateStatus1.getPTransformId());
+    assertEquals("ptransformId2", activeStateStatus2.getPTransformId());
+    assertEquals("ptransformIdName1", activeStateStatus1.getPTransformUniqueName());
+    assertEquals("ptransformIdName2", activeStateStatus2.getPTransformUniqueName());
+    assertEquals(Thread.currentThread(), activeStateStatus1.getTrackedThread());
+    assertEquals(Thread.currentThread(), activeStateStatus2.getTrackedThread());
+    assertThat(
+        activeStateStatus1.getLastTransitionTimeMillis(),
+        greaterThan(activeBundleStatus1.getLastTransitionTimeMillis()));
+    assertThat(
+        activeStateStatus2.getLastTransitionTimeMillis(),
+        greaterThan(activeBundleStatus2.getLastTransitionTimeMillis()));
+
+    // Validate intermediate monitoring data
+    Map<String, ByteString> intermediateResults1 = new HashMap<>();
+    Map<String, ByteString> intermediateResults2 = new HashMap<>();
+    tracker1.updateIntermediateMonitoringData(intermediateResults1);
+    tracker2.updateIntermediateMonitoringData(intermediateResults2);
+    assertThat(
+        MonitoringInfoEncodings.decodeInt64Counter(intermediateResults1.get("shortId1")),
+        // Because we are using lazySet, we aren't guaranteed to see the latest value.
+        // The CountDownLatch ensures that we will see either the prior value or
+        // the latest value.
+        anyOf(equalTo(900L), equalTo(1000L)));
+    assertThat(
+        MonitoringInfoEncodings.decodeInt64Counter(intermediateResults2.get("shortId2")),
+        // Because we are using lazySet, we aren't guaranteed to see the latest value.
+        // The CountDownLatch ensures that we will see either the prior value or
+        // the latest value.
+        anyOf(equalTo(900L), equalTo(1000L)));
+
+    waitTillIntermediateReport.countDown();
+    waitForMoreSamples.await();
+    state1.deactivate();
+    state2.deactivate();
+
+    waitTillStatesDeactivated.countDown();
+    waitForEvenMoreSamples.await();
+
+    // Check the status once the states are deactivated but the bundle is still active
+    ExecutionStateTrackerStatus inactiveStateStatus1 = tracker1.getStatus();
+    ExecutionStateTrackerStatus inactiveStateStatus2 = tracker2.getStatus();
+    assertNull(inactiveStateStatus1.getPTransformId());
+    assertNull(inactiveStateStatus2.getPTransformId());
+    assertNull(inactiveStateStatus1.getPTransformUniqueName());
+    assertNull(inactiveStateStatus2.getPTransformUniqueName());
+    assertEquals(Thread.currentThread(), inactiveStateStatus1.getTrackedThread());
+    assertEquals(Thread.currentThread(), inactiveStateStatus2.getTrackedThread());
+    assertThat(
+        inactiveStateStatus1.getLastTransitionTimeMillis(),
+        greaterThan(activeStateStatus1.getLastTransitionTimeMillis()));
+    assertThat(
+        inactiveStateStatus2.getLastTransitionTimeMillis(),
+        greaterThan(activeStateStatus2.getLastTransitionTimeMillis()));
+
+    // Validate the final monitoring data
+    Map<String, ByteString> finalResults1 = new HashMap<>();
+    Map<String, ByteString> finalResults2 = new HashMap<>();
+    tracker1.updateFinalMonitoringData(finalResults1);
+    tracker2.updateFinalMonitoringData(finalResults2);
+    assertThat(
+        MonitoringInfoEncodings.decodeInt64Counter(finalResults1.get("shortId1")),
+        // Because we are using lazySet, we aren't guaranteed to see the latest value.
+        // The CountDownLatch ensures that we will see either the prior value or
+        // the latest value.
+        anyOf(equalTo(1400L), equalTo(1500L)));
+    assertThat(
+        MonitoringInfoEncodings.decodeInt64Counter(finalResults2.get("shortId2")),
+        // Because we are using lazySet, we aren't guaranteed to see the latest value.
+        // The CountDownLatch ensures that we will see either the prior value or
+        // the latest value.
+        anyOf(equalTo(1400L), equalTo(1500L)));
+
+    tracker1.reset();
+    tracker2.reset();
+
+    // Shouldn't have a status returned since there is no active bundle.
+    assertNull(tracker1.getStatus());
+    assertNull(tracker2.getStatus());
+
+    sampler.stop();
+    expectedLogs.verifyNotLogged("Operation ongoing");
+  }
+
+  @Test
+  public void testSamplingDoesntReportDuplicateFinalResults() throws Exception {
+    MillisProvider clock = mock(MillisProvider.class);
+    ExecutionStateSampler sampler =
+        new ExecutionStateSampler(
+            PipelineOptionsFactory.fromArgs("--experiments=state_sampling_period_millis=10")
+                .create(),
+            clock);
+    ExecutionStateTracker tracker1 = sampler.create();
+    ExecutionState state1 =
+        tracker1.create("shortId1", "ptransformId1", "ptransformIdName1", "process");
+
+    ExecutionStateTracker tracker2 = sampler.create();
+    ExecutionState state2 =
+        tracker2.create("shortId2", "ptransformId2", "ptransformIdName2", "process");
+
+    CountDownLatch waitTillActive = new CountDownLatch(1);
+    CountDownLatch waitForSamples = new CountDownLatch(1);
+    Thread testThread = Thread.currentThread();
+    Mockito.when(clock.getMillis())
+        .thenAnswer(
+            new Answer<Long>() {
+              private long currentTime;
+
+              @Override
+              public Long answer(InvocationOnMock invocation) throws Throwable {
+                if (Thread.currentThread().equals(testThread)) {
+                  return 0L;
+                } else {
+                  // Block the state sampling thread till the state is active
+                  // and unblock the state transition once a certain number of samples
+                  // have been taken.
+                  waitTillActive.await();
+                  if (currentTime < 1000L) {
+                    currentTime += 100L;
+                  } else {
+                    waitForSamples.countDown();
+                  }
+                  return currentTime;
+                }
+              }
+            });
+
+    tracker1.start("bundleId1");
+    tracker2.start("bundleId2");
+
+    state1.activate();
+    state2.activate();
+    waitTillActive.countDown();
+    waitForSamples.await();
+    state1.deactivate();
+    state2.deactivate();
+
+    Map<String, ByteString> intermediateResults1 = new HashMap<>();
+    Map<String, ByteString> intermediateResults2 = new HashMap<>();
+    tracker1.updateIntermediateMonitoringData(intermediateResults1);
+    tracker2.updateIntermediateMonitoringData(intermediateResults2);
+    assertThat(
+        MonitoringInfoEncodings.decodeInt64Counter(intermediateResults1.get("shortId1")),
+        // Because we are using lazySet, we aren't guaranteed to see the latest value.
+        // The CountDownLatch ensures that we will see either the prior value or
+        // the latest value.
+        anyOf(equalTo(900L), equalTo(1000L)));
+    assertThat(
+        MonitoringInfoEncodings.decodeInt64Counter(intermediateResults2.get("shortId2")),
+        // Because we are using lazySet, we aren't guaranteed to see the latest value.
+        // The CountDownLatch ensures that we will see either the prior value or
+        // the latest value.
+        anyOf(equalTo(900L), equalTo(1000L)));
+
+    state1.deactivate();
+    state2.deactivate();
+
+    Map<String, ByteString> finalResults1 = new HashMap<>();
+    Map<String, ByteString> finalResults2 = new HashMap<>();
+    tracker1.updateFinalMonitoringData(finalResults1);
+    tracker2.updateFinalMonitoringData(finalResults2);
+
+    assertTrue(finalResults1.isEmpty());
+    assertTrue(finalResults2.isEmpty());
+
+    tracker1.reset();
+    tracker2.reset();
+
+    sampler.stop();
+    expectedLogs.verifyNotLogged("Operation ongoing");
+  }
+
+  @Test
+  public void testTrackerReuse() throws Exception {
+    MillisProvider clock = mock(MillisProvider.class);
+    ExecutionStateSampler sampler =
+        new ExecutionStateSampler(
+            PipelineOptionsFactory.fromArgs("--experiments=state_sampling_period_millis=10")
+                .create(),
+            clock);
+    ExecutionStateTracker tracker = sampler.create();
+    ExecutionState state = tracker.create("shortId", "ptransformId", "ptransformIdName", "process");
+
+    CountDownLatch waitTillActive = new CountDownLatch(1);
+    CountDownLatch waitTillSecondStateActive = new CountDownLatch(1);
+    CountDownLatch waitForSamples = new CountDownLatch(1);
+    CountDownLatch waitForMoreSamples = new CountDownLatch(1);
+    Thread testThread = Thread.currentThread();
+    Mockito.when(clock.getMillis())
+        .thenAnswer(
+            new Answer<Long>() {
+              private long currentTime;
+
+              @Override
+              public Long answer(InvocationOnMock invocation) throws Throwable {
+                if (Thread.currentThread().equals(testThread)) {
+                  return 0L;
+                } else {
+                  // Block the state sampling thread till the state is active
+                  // and unblock the state transition once a certain number of samples
+                  // have been taken.
+                  if (currentTime < 1000L) {
+                    waitTillActive.await();
+                    currentTime += 100L;
+                  } else if (currentTime < 1500L) {
+                    waitForSamples.countDown();
+                    waitTillSecondStateActive.await();
+                    currentTime += 100L;
+                  } else {
+                    waitForMoreSamples.countDown();
+                  }
+                  return currentTime;
+                }
+              }
+            });
+
+    {
+      tracker.start("bundleId1");
+      state.activate();
+      waitTillActive.countDown();
+      waitForSamples.await();
+      state.deactivate();
+      Map<String, ByteString> finalResults = new HashMap<>();
+      tracker.updateFinalMonitoringData(finalResults);
+      assertThat(
+          MonitoringInfoEncodings.decodeInt64Counter(finalResults.get("shortId")),
+          // Because we are using lazySet, we aren't guaranteed to see the latest value.
+          // The CountDownLatch ensures that we will see either the prior value or
+          // the latest value.
+          anyOf(equalTo(900L), equalTo(1000L)));
+      tracker.reset();
+    }
+
+    {
+      tracker.start("bundleId2");
+      state.activate();
+      waitTillSecondStateActive.countDown();
+      waitForMoreSamples.await();
+      state.deactivate();
+      Map<String, ByteString> finalResults = new HashMap<>();
+      tracker.updateFinalMonitoringData(finalResults);
+      assertThat(
+          MonitoringInfoEncodings.decodeInt64Counter(finalResults.get("shortId")),
+          // Because we are using lazySet, we aren't guaranteed to see the latest value.
+          // The CountDownLatch ensures that we will see either the prior value or
+          // the latest value.
+          anyOf(equalTo(400L), equalTo(500L)));
+      tracker.reset();
+    }
+
+    expectedLogs.verifyNotLogged("Operation ongoing");
+  }
+
+  @Test
+  public void testLullDetectionOccursInActiveBundle() throws Exception {
+    MillisProvider clock = mock(MillisProvider.class);
+    ExecutionStateSampler sampler =
+        new ExecutionStateSampler(
+            PipelineOptionsFactory.fromArgs("--experiments=state_sampling_period_millis=10")
+                .create(),
+            clock);
+    ExecutionStateTracker tracker = sampler.create();
+
+    CountDownLatch waitTillActive = new CountDownLatch(1);
+    CountDownLatch waitForSamples = new CountDownLatch(10);
+    Thread testThread = Thread.currentThread();
+    Mockito.when(clock.getMillis())
+        .thenAnswer(
+            new Answer<Long>() {
+              private long currentTime;
+
+              @Override
+              public Long answer(InvocationOnMock invocation) throws Throwable {
+                if (Thread.currentThread().equals(testThread)) {
+                  return 0L;
+                } else {
+                  // Block the state sampling thread till the bundle is active
+                  // and unblock the state transition once a certain number of samples
+                  // have been taken.
+                  waitTillActive.await();
+                  waitForSamples.countDown();
+                  currentTime += Duration.standardMinutes(1).getMillis();
+                  return currentTime;
+                }
+              }
+            });
+
+    tracker.start("bundleId");
+    waitTillActive.countDown();
+    waitForSamples.await();
+    tracker.reset();
+
+    sampler.stop();
+    expectedLogs.verifyWarn("Operation ongoing in bundle bundleId for at least");
+  }
+
+  @Test
+  public void testLullDetectionOccursInActiveState() throws Exception {
+    MillisProvider clock = mock(MillisProvider.class);
+    ExecutionStateSampler sampler =
+        new ExecutionStateSampler(
+            PipelineOptionsFactory.fromArgs("--experiments=state_sampling_period_millis=10")
+                .create(),
+            clock);
+    ExecutionStateTracker tracker = sampler.create();
+    ExecutionState state = tracker.create("shortId", "ptransformId", "ptransformIdName", "process");
+
+    CountDownLatch waitTillActive = new CountDownLatch(1);
+    CountDownLatch waitForSamples = new CountDownLatch(10);
+    Thread testThread = Thread.currentThread();
+    Mockito.when(clock.getMillis())
+        .thenAnswer(
+            new Answer<Long>() {
+              private long currentTime;
+
+              @Override
+              public Long answer(InvocationOnMock invocation) throws Throwable {
+                if (Thread.currentThread().equals(testThread)) {
+                  return 0L;
+                } else {
+                  // Block the state sampling thread till the state is active
+                  // and unblock the state transition once a certain number of samples
+                  // have been taken.
+                  waitTillActive.await();
+                  waitForSamples.countDown();
+                  currentTime += Duration.standardMinutes(1).getMillis();
+                  return currentTime;
+                }
+              }
+            });
+
+    tracker.start("bundleId");
+    state.activate();
+    waitTillActive.countDown();
+    waitForSamples.await();
+    state.deactivate();
+    tracker.reset();
+
+    sampler.stop();
+    expectedLogs.verifyWarn("Operation ongoing in bundle bundleId for PTransform");
+  }
+}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
index 2cd25707305..450628a31a9 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
@@ -70,6 +70,7 @@ import org.apache.beam.fn.harness.BeamFnDataReadRunner;
 import org.apache.beam.fn.harness.Cache;
 import org.apache.beam.fn.harness.Caches;
 import org.apache.beam.fn.harness.PTransformRunnerFactory;
+import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker;
 import org.apache.beam.fn.harness.control.FinalizeBundleHandler.CallbackRegistration;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessor;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessorCache;
@@ -113,7 +114,6 @@ import org.apache.beam.runners.core.construction.ModelCoders;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.ParDoTranslation;
 import org.apache.beam.runners.core.construction.Timer;
-import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
 import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.runners.core.metrics.ShortIdMap;
 import org.apache.beam.sdk.coders.KvCoder;
@@ -148,6 +148,7 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterable
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles;
 import org.joda.time.Instant;
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -170,11 +171,19 @@ public class ProcessBundleHandlerTest {
 
   @Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool);
   @Mock private BeamFnDataClient beamFnDataClient;
+  private ExecutionStateSampler executionStateSampler;
 
   @Before
   public void setUp() {
     MockitoAnnotations.initMocks(this);
     TestBundleProcessor.resetCnt = 0;
+    executionStateSampler =
+        new ExecutionStateSampler(PipelineOptionsFactory.create(), System::currentTimeMillis);
+  }
+
+  @After
+  public void tearDown() {
+    executionStateSampler.stop();
   }
 
   private static class TestDoFn extends DoFn<String, String> {
@@ -362,6 +371,7 @@ public class ProcessBundleHandlerTest {
             null /* beamFnStateClient */,
             null /* finalizeBundleHandler */,
             new ShortIdMap(),
+            executionStateSampler,
             ImmutableMap.of(),
             Caches.noop(),
             new BundleProcessorCache());
@@ -391,6 +401,7 @@ public class ProcessBundleHandlerTest {
             null /* beamFnStateClient */,
             null /* finalizeBundleHandler */,
             new ShortIdMap(),
+            executionStateSampler,
             ImmutableMap.of(),
             Caches.noop(),
             new BundleProcessorCache());
@@ -467,6 +478,7 @@ public class ProcessBundleHandlerTest {
             null /* beamFnStateClient */,
             null /* finalizeBundleHandler */,
             new ShortIdMap(),
+            executionStateSampler,
             ImmutableMap.of(
                 DATA_INPUT_URN, startFinishRecorder,
                 DATA_OUTPUT_URN, startFinishRecorder),
@@ -572,6 +584,7 @@ public class ProcessBundleHandlerTest {
             null /* beamFnStateClient */,
             null /* finalizeBundleHandler */,
             new ShortIdMap(),
+            executionStateSampler,
             urnToPTransformRunnerFactoryMap,
             Caches.noop(),
             new BundleProcessorCache());
@@ -622,6 +635,7 @@ public class ProcessBundleHandlerTest {
             null /* beamFnStateGrpcClientCache */,
             null /* finalizeBundleHandler */,
             new ShortIdMap(),
+            executionStateSampler,
             ImmutableMap.of(DATA_INPUT_URN, (context) -> null),
             Caches.noop(),
             new TestBundleProcessorCache());
@@ -748,10 +762,7 @@ public class ProcessBundleHandlerTest {
     assertNull(bundleProcessor.getInstructionId());
     assertNull(bundleProcessor.getCacheTokens());
     assertNull(bundleCache.peek("A"));
-    verify(startFunctionRegistry, times(1)).reset();
-    verify(finishFunctionRegistry, times(1)).reset();
     verify(splitListener, times(1)).clear();
-    verify(pCollectionConsumerRegistry, times(1)).reset();
     verify(metricsContainerRegistry, times(1)).reset();
     verify(stateTracker, times(1)).reset();
     verify(bundleFinalizationCallbacks, times(1)).clear();
@@ -787,6 +798,7 @@ public class ProcessBundleHandlerTest {
             null /* beamFnStateGrpcClientCache */,
             null /* finalizeBundleHandler */,
             new ShortIdMap(),
+            executionStateSampler,
             ImmutableMap.of(
                 DATA_INPUT_URN,
                 (context) -> {
@@ -830,6 +842,7 @@ public class ProcessBundleHandlerTest {
             null /* beamFnStateGrpcClientCache */,
             mockFinalizeBundleHandler,
             new ShortIdMap(),
+            executionStateSampler,
             ImmutableMap.of(
                 DATA_INPUT_URN,
                 (PTransformRunnerFactory<Object>)
@@ -886,6 +899,7 @@ public class ProcessBundleHandlerTest {
             null /* beamFnStateGrpcClientCache */,
             null /* finalizeBundleHandler */,
             new ShortIdMap(),
+            executionStateSampler,
             ImmutableMap.of(
                 DATA_INPUT_URN,
                 (PTransformRunnerFactory<Object>)
@@ -1068,6 +1082,7 @@ public class ProcessBundleHandlerTest {
         null /* beamFnStateClient */,
         null /* finalizeBundleHandler */,
         new ShortIdMap(),
+        executionStateSampler,
         urnToPTransformRunnerFactoryMap,
         Caches.noop(),
         new BundleProcessorCache());
@@ -1390,6 +1405,7 @@ public class ProcessBundleHandlerTest {
             null /* beamFnStateGrpcClientCache */,
             null /* finalizeBundleHandler */,
             new ShortIdMap(),
+            executionStateSampler,
             ImmutableMap.of(
                 DATA_INPUT_URN,
                 (PTransformRunnerFactory<Object>)
@@ -1459,6 +1475,7 @@ public class ProcessBundleHandlerTest {
             null /* beamFnStateGrpcClientCache */,
             null /* finalizeBundleHandler */,
             new ShortIdMap(),
+            executionStateSampler,
             ImmutableMap.of(
                 DATA_INPUT_URN,
                 (PTransformRunnerFactory<Object>)
@@ -1512,6 +1529,7 @@ public class ProcessBundleHandlerTest {
             null /* beamFnStateGrpcClientCache */,
             null /* finalizeBundleHandler */,
             new ShortIdMap(),
+            executionStateSampler,
             ImmutableMap.of(
                 DATA_INPUT_URN,
                 (PTransformRunnerFactory<Object>)
@@ -1594,6 +1612,7 @@ public class ProcessBundleHandlerTest {
             mockBeamFnStateGrpcClient,
             null /* finalizeBundleHandler */,
             new ShortIdMap(),
+            executionStateSampler,
             ImmutableMap.of(
                 DATA_INPUT_URN,
                 new PTransformRunnerFactory<Object>() {
@@ -1647,6 +1666,7 @@ public class ProcessBundleHandlerTest {
             null /* beamFnStateGrpcClientCache */,
             null /* finalizeBundleHandler */,
             new ShortIdMap(),
+            executionStateSampler,
             ImmutableMap.of(
                 DATA_INPUT_URN,
                 new PTransformRunnerFactory<Object>() {
@@ -1762,6 +1782,7 @@ public class ProcessBundleHandlerTest {
             null /* beamFnStateClient */,
             null /* finalizeBundleHandler */,
             new ShortIdMap(),
+            executionStateSampler,
             ImmutableMap.of(DATA_INPUT_URN, startFinishGuard),
             Caches.noop(),
             bundleProcessorCache);
@@ -1880,6 +1901,7 @@ public class ProcessBundleHandlerTest {
             null /* beamFnStateGrpcClientCache */,
             null /* finalizeBundleHandler */,
             new ShortIdMap(),
+            executionStateSampler,
             ImmutableMap.of(
                 DATA_INPUT_URN,
                 new PTransformRunnerFactory<Object>() {
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
index 0aa0e944ae5..60c874c6a37 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
@@ -37,12 +37,12 @@ import java.util.List;
 import java.util.Map;
 import org.apache.beam.fn.harness.HandlesSplits;
 import org.apache.beam.fn.harness.control.BundleProgressReporter;
+import org.apache.beam.fn.harness.control.ExecutionStateSampler;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
 import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
 import org.apache.beam.runners.core.construction.SdkComponents;
 import org.apache.beam.runners.core.metrics.DistributionData;
-import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
 import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Labels;
@@ -53,11 +53,14 @@ import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterable;
 import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterator;
 import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -101,6 +104,18 @@ public class PCollectionConsumerRegistryTest {
     }
   }
 
+  private ExecutionStateSampler sampler;
+
+  @Before
+  public void setUp() throws Exception {
+    sampler = new ExecutionStateSampler(PipelineOptionsFactory.create(), System::currentTimeMillis);
+  }
+
+  @After
+  public void tearDown() throws Exception {
+    sampler.stop();
+  }
+
   @Test
   public void singleConsumer() throws Exception {
     final String pTransformIdA = "pTransformIdA";
@@ -111,13 +126,13 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
-            mock(ExecutionStateTracker.class),
+            sampler.create(),
             shortIds,
             reporterAndRegistrar,
             TEST_DESCRIPTOR);
     FnDataReceiver<WindowedValue<String>> consumerA1 = mock(FnDataReceiver.class);
 
-    consumers.register(P_COLLECTION_A, pTransformIdA, consumerA1);
+    consumers.register(P_COLLECTION_A, pTransformIdA, pTransformIdA + "Name", consumerA1);
 
     FnDataReceiver<WindowedValue<String>> wrapperConsumer =
         (FnDataReceiver<WindowedValue<String>>)
@@ -172,13 +187,13 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
-            mock(ExecutionStateTracker.class),
+            sampler.create(),
             shortIds,
             reporterAndRegistrar,
             TEST_DESCRIPTOR);
     FnDataReceiver<WindowedValue<String>> consumer = mock(FnDataReceiver.class);
 
-    consumers.register(P_COLLECTION_A, pTransformId, consumer);
+    consumers.register(P_COLLECTION_A, pTransformId, pTransformId + "Name", consumer);
 
     FnDataReceiver<WindowedValue<String>> wrapperConsumer =
         (FnDataReceiver<WindowedValue<String>>)
@@ -199,7 +214,7 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
-            mock(ExecutionStateTracker.class),
+            sampler.create(),
             shortIds,
             reporterAndRegistrar,
             TEST_DESCRIPTOR);
@@ -258,15 +273,15 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
-            mock(ExecutionStateTracker.class),
+            sampler.create(),
             shortIds,
             reporterAndRegistrar,
             TEST_DESCRIPTOR);
     FnDataReceiver<WindowedValue<String>> consumerA1 = mock(FnDataReceiver.class);
     FnDataReceiver<WindowedValue<String>> consumerA2 = mock(FnDataReceiver.class);
 
-    consumers.register(P_COLLECTION_A, pTransformIdA, consumerA1);
-    consumers.register(P_COLLECTION_A, pTransformIdB, consumerA2);
+    consumers.register(P_COLLECTION_A, pTransformIdA, pTransformIdA + "Name", consumerA1);
+    consumers.register(P_COLLECTION_A, pTransformIdB, pTransformIdB + "Name", consumerA2);
 
     FnDataReceiver<WindowedValue<String>> wrapperConsumer =
         (FnDataReceiver<WindowedValue<String>>)
@@ -322,15 +337,15 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
-            mock(ExecutionStateTracker.class),
+            sampler.create(),
             shortIds,
             reporterAndRegistrar,
             TEST_DESCRIPTOR);
     FnDataReceiver<WindowedValue<String>> consumerA1 = mock(FnDataReceiver.class);
     FnDataReceiver<WindowedValue<String>> consumerA2 = mock(FnDataReceiver.class);
 
-    consumers.register(P_COLLECTION_A, pTransformId, consumerA1);
-    consumers.register(P_COLLECTION_A, pTransformId, consumerA2);
+    consumers.register(P_COLLECTION_A, pTransformId, pTransformId + "Name", consumerA1);
+    consumers.register(P_COLLECTION_A, pTransformId, pTransformId + "Name", consumerA2);
 
     FnDataReceiver<WindowedValue<String>> wrapperConsumer =
         (FnDataReceiver<WindowedValue<String>>)
@@ -352,19 +367,19 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
-            mock(ExecutionStateTracker.class),
+            sampler.create(),
             shortIds,
             reporterAndRegistrar,
             TEST_DESCRIPTOR);
     FnDataReceiver<WindowedValue<String>> consumerA1 = mock(FnDataReceiver.class);
     FnDataReceiver<WindowedValue<String>> consumerA2 = mock(FnDataReceiver.class);
 
-    consumers.register(P_COLLECTION_A, pTransformId, consumerA1);
+    consumers.register(P_COLLECTION_A, pTransformId, pTransformId + "Name", consumerA1);
     consumers.getMultiplexingConsumer(P_COLLECTION_A);
 
     expectedException.expect(RuntimeException.class);
     expectedException.expectMessage("cannot be register()-d after");
-    consumers.register(P_COLLECTION_A, pTransformId, consumerA2);
+    consumers.register(P_COLLECTION_A, pTransformId, pTransformId + "Name", consumerA2);
   }
 
   @Test
@@ -377,15 +392,15 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
-            mock(ExecutionStateTracker.class),
+            sampler.create(),
             shortIds,
             reporterAndRegistrar,
             TEST_DESCRIPTOR);
     FnDataReceiver<WindowedValue<String>> consumerA1 = mock(FnDataReceiver.class);
     FnDataReceiver<WindowedValue<String>> consumerA2 = mock(FnDataReceiver.class);
 
-    consumers.register(P_COLLECTION_A, "pTransformA", consumerA1);
-    consumers.register(P_COLLECTION_A, "pTransformB", consumerA2);
+    consumers.register(P_COLLECTION_A, "pTransformA", "pTransformAName", consumerA1);
+    consumers.register(P_COLLECTION_A, "pTransformB", "pTransformBName", consumerA2);
 
     FnDataReceiver<WindowedValue<String>> wrapperConsumer =
         (FnDataReceiver<WindowedValue<String>>)
@@ -413,13 +428,13 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
-            mock(ExecutionStateTracker.class),
+            sampler.create(),
             shortIds,
             reporterAndRegistrar,
             TEST_DESCRIPTOR);
     SplittingReceiver consumerA1 = mock(SplittingReceiver.class);
 
-    consumers.register(P_COLLECTION_A, pTransformIdA, consumerA1);
+    consumers.register(P_COLLECTION_A, pTransformIdA, pTransformIdA + "Name", consumerA1);
 
     FnDataReceiver<WindowedValue<String>> wrapperConsumer =
         (FnDataReceiver<WindowedValue<String>>)
@@ -444,13 +459,13 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
-            mock(ExecutionStateTracker.class),
+            sampler.create(),
             shortIds,
             reporterAndRegistrar,
             TEST_DESCRIPTOR);
     FnDataReceiver<WindowedValue<Iterable<String>>> consumerA1 = mock(FnDataReceiver.class);
 
-    consumers.register(P_COLLECTION_B, pTransformIdA, consumerA1);
+    consumers.register(P_COLLECTION_B, pTransformIdA, pTransformIdA + "Name", consumerA1);
 
     FnDataReceiver<WindowedValue<Iterable<String>>> wrapperConsumer =
         (FnDataReceiver<WindowedValue<Iterable<String>>>)
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistryTest.java
index b60ebef5344..35025a61a71 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistryTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistryTest.java
@@ -17,15 +17,25 @@
  */
 package org.apache.beam.fn.harness.data;
 
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
 import static org.powermock.api.mockito.PowerMockito.mockStatic;
 
-import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
+import java.util.concurrent.atomic.AtomicBoolean;
+import org.apache.beam.fn.harness.control.ExecutionStateSampler;
+import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker;
+import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTrackerStatus;
 import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
+import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
+import org.apache.beam.runners.core.metrics.ShortIdMap;
 import org.apache.beam.sdk.function.ThrowingRunnable;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.junit.After;
+import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.powermock.api.mockito.PowerMockito;
@@ -37,41 +47,89 @@ import org.powermock.modules.junit4.PowerMockRunner;
 @PrepareForTest(MetricsEnvironment.class)
 public class PTransformFunctionRegistryTest {
 
+  private ExecutionStateSampler sampler;
+
+  @Before
+  public void setUp() {
+    sampler = new ExecutionStateSampler(PipelineOptionsFactory.create(), System::currentTimeMillis);
+  }
+
+  @After
+  public void tearDown() {
+    sampler.stop();
+  }
+
   @Test
-  public void functionsAreInvokedIndirectlyAfterRegisteringAndInvoking() throws Exception {
+  public void testStateTrackerRecordsStateTransitions() throws Exception {
+    ExecutionStateTracker executionStateTracker = sampler.create();
     PTransformFunctionRegistry testObject =
         new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+            mock(MetricsContainerStepMap.class),
+            new ShortIdMap(),
+            executionStateTracker,
+            Urns.START_BUNDLE_MSECS);
 
-    ThrowingRunnable runnableA = mock(ThrowingRunnable.class);
-    ThrowingRunnable runnableB = mock(ThrowingRunnable.class);
-    testObject.register("pTransformA", runnableA);
-    testObject.register("pTransformB", runnableB);
+    final AtomicBoolean runnableAWasCalled = new AtomicBoolean();
+    final AtomicBoolean runnableBWasCalled = new AtomicBoolean();
+    ThrowingRunnable runnableA =
+        new ThrowingRunnable() {
+          @Override
+          public void run() throws Exception {
+            runnableAWasCalled.set(true);
+            ExecutionStateTrackerStatus executionStateTrackerStatus =
+                executionStateTracker.getStatus();
+            assertNotNull(executionStateTrackerStatus);
+            assertEquals(Thread.currentThread(), executionStateTrackerStatus.getTrackedThread());
+            assertEquals("pTransformA", executionStateTrackerStatus.getPTransformId());
+          }
+        };
+    ThrowingRunnable runnableB =
+        new ThrowingRunnable() {
+          @Override
+          public void run() throws Exception {
+            runnableBWasCalled.set(true);
+            ExecutionStateTrackerStatus executionStateTrackerStatus =
+                executionStateTracker.getStatus();
+            assertNotNull(executionStateTrackerStatus);
+            assertEquals(Thread.currentThread(), executionStateTrackerStatus.getTrackedThread());
+            assertEquals("pTransformB", executionStateTrackerStatus.getPTransformId());
+          }
+        };
+    testObject.register("pTransformA", "pTransformAName", runnableA);
+    testObject.register("pTransformB", "pTransformBName", runnableB);
 
+    executionStateTracker.start("testBundleId");
     for (ThrowingRunnable func : testObject.getFunctions()) {
       func.run();
     }
+    executionStateTracker.reset();
 
-    verify(runnableA, times(1)).run();
-    verify(runnableB, times(1)).run();
+    assertTrue(runnableAWasCalled.get());
+    assertTrue(runnableBWasCalled.get());
   }
 
   @Test
-  public void testScopedMetricContainerInvokedUponRunningFunctions() throws Exception {
+  public void testMetricsUponRunningFunctions() throws Exception {
+    ExecutionStateTracker executionStateTracker = sampler.create();
     mockStatic(MetricsEnvironment.class);
     MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
     PTransformFunctionRegistry testObject =
         new PTransformFunctionRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class), "start");
+            metricsContainerRegistry,
+            new ShortIdMap(),
+            executionStateTracker,
+            Urns.START_BUNDLE_MSECS);
 
     ThrowingRunnable runnableA = mock(ThrowingRunnable.class);
     ThrowingRunnable runnableB = mock(ThrowingRunnable.class);
-    testObject.register("pTransformA", runnableA);
-    testObject.register("pTransformB", runnableB);
+    testObject.register("pTransformA", "pTransformAName", runnableA);
+    testObject.register("pTransformB", "pTransformBName", runnableB);
 
+    executionStateTracker.start("testBundleId");
     for (ThrowingRunnable func : testObject.getFunctions()) {
       func.run();
     }
+    executionStateTracker.reset();
 
     // Verify that static scopedMetricsContainer is called with pTransformA's container.
     PowerMockito.verifyStatic(MetricsEnvironment.class, times(1));
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/status/BeamFnStatusClientTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/status/BeamFnStatusClientTest.java
index c0229f23fbe..67f040abbc9 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/status/BeamFnStatusClientTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/status/BeamFnStatusClientTest.java
@@ -35,6 +35,8 @@ import java.util.UUID;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.LinkedBlockingQueue;
 import org.apache.beam.fn.harness.Caches;
+import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker;
+import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTrackerStatus;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessor;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessorCache;
@@ -42,7 +44,6 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.WorkerStatusRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.WorkerStatusResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnWorkerStatusGrpc.BeamFnWorkerStatusImplBase;
 import org.apache.beam.model.pipeline.v1.Endpoints;
-import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
 import org.apache.beam.sdk.fn.channel.ManagedChannelFactory;
 import org.apache.beam.sdk.fn.test.TestStreams;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
@@ -54,6 +55,7 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
+/** Tests for {@link BeamFnStatusClient}. */
 @RunWith(JUnit4.class)
 public class BeamFnStatusClientTest {
   private final Endpoints.ApiServiceDescriptor apiServiceDescriptor =
@@ -70,9 +72,10 @@ public class BeamFnStatusClientTest {
       BundleProcessor processor = mock(BundleProcessor.class);
       ExecutionStateTracker executionStateTracker = mock(ExecutionStateTracker.class);
       when(processor.getStateTracker()).thenReturn(executionStateTracker);
-      when(executionStateTracker.getMillisSinceLastTransition())
-          .thenReturn(Integer.toUnsignedLong((10 - i) * 1000));
-      when(executionStateTracker.getTrackedThread()).thenReturn(Thread.currentThread());
+      when(executionStateTracker.getStatus())
+          .thenReturn(
+              ExecutionStateTrackerStatus.create(
+                  "ptransformId", "ptransformIdName", Thread.currentThread(), i * 1000));
       String instruction = Integer.toString(i);
       when(processorCache.find(instruction)).thenReturn(processor);
       bundleProcessorMap.put(instruction, processor);