You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2022/09/01 15:37:20 UTC

[beam] branch master updated: [fixes #22980] Migrate BeamFnLoggingClient to the new execution state sampler. (#22981)

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 2df47e7657c [fixes #22980] Migrate BeamFnLoggingClient to the new execution state sampler. (#22981)
2df47e7657c is described below

commit 2df47e7657ca2a9c3fd7b3c3fb578913d4ec4ec1
Author: Luke Cwik <lc...@google.com>
AuthorDate: Thu Sep 1 08:37:13 2022 -0700

    [fixes #22980] Migrate BeamFnLoggingClient to the new execution state sampler. (#22981)
    
    * [fixes #22980] Migrate BeamFnLoggingClient to the new execution state sampler.
    
    This fixes the logging issue so that log entries include the transform id again.
    
    There is now a public unit test that will prevent a similar regression in the future.
    
    * Migrate test to use ExpectedLogs due to logging backend change.
---
 runners/java-fn-execution/build.gradle             |   4 +-
 ...eCountingExecutableStageContextFactoryTest.java |  47 +++-----
 .../fnexecution/control/RemoteExecutionTest.java   | 133 ++++++++++++++++++++-
 .../java/org/apache/beam/fn/harness/FnHarness.java |   1 +
 .../fn/harness/control/ExecutionStateSampler.java  |  29 +++--
 .../fn/harness/logging/BeamFnLoggingClient.java    |  27 +++--
 .../harness/control/ExecutionStateSamplerTest.java |  20 +++-
 7 files changed, 206 insertions(+), 55 deletions(-)

diff --git a/runners/java-fn-execution/build.gradle b/runners/java-fn-execution/build.gradle
index cc72152ffdd..8960f237a8e 100644
--- a/runners/java-fn-execution/build.gradle
+++ b/runners/java-fn-execution/build.gradle
@@ -42,7 +42,9 @@ dependencies {
   testImplementation project(path: ":runners:core-java", configuration: "testRuntimeMigration")
   testImplementation library.java.junit
   testImplementation library.java.mockito_core
-  testRuntimeOnly library.java.slf4j_simple
+  // We want to use jdk logging backend to appropriately simulate logging setup
+  // for RemoteExecutionTest
+  testRuntimeOnly library.java.slf4j_jdk14
 }
 
 test {
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ReferenceCountingExecutableStageContextFactoryTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ReferenceCountingExecutableStageContextFactoryTest.java
index 3d407bc2aa7..be28ae06acc 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ReferenceCountingExecutableStageContextFactoryTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ReferenceCountingExecutableStageContextFactoryTest.java
@@ -17,18 +17,16 @@
  */
 package org.apache.beam.runners.fnexecution.control;
 
-import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
-import java.io.ByteArrayOutputStream;
-import java.io.PrintStream;
 import org.apache.beam.runners.fnexecution.control.ReferenceCountingExecutableStageContextFactory.Creator;
 import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Charsets;
+import org.apache.beam.sdk.testing.ExpectedLogs;
 import org.junit.Assert;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -37,6 +35,10 @@ import org.junit.runners.JUnit4;
 @RunWith(JUnit4.class)
 public class ReferenceCountingExecutableStageContextFactoryTest {
 
+  @Rule
+  public final ExpectedLogs expectedLogs =
+      ExpectedLogs.none(ReferenceCountingExecutableStageContextFactory.class);
+
   @Test
   public void testCreateReuseReleaseCreate() throws Exception {
 
@@ -85,30 +87,17 @@ public class ReferenceCountingExecutableStageContextFactoryTest {
 
   @Test
   public void testCatchThrowablesAndLogThem() throws Exception {
-    PrintStream oldErr = System.err;
-    oldErr.flush();
-    ByteArrayOutputStream baos = new ByteArrayOutputStream();
-    PrintStream newErr = new PrintStream(baos);
-    try {
-      System.setErr(newErr);
-      Creator creator = mock(Creator.class);
-      ExecutableStageContext c1 = mock(ExecutableStageContext.class);
-      when(creator.apply(any(JobInfo.class))).thenReturn(c1);
-      // throw an Throwable and ensure that it is caught and logged.
-      doThrow(new NoClassDefFoundError()).when(c1).close();
-      ReferenceCountingExecutableStageContextFactory factory =
-          ReferenceCountingExecutableStageContextFactory.create(creator, (x) -> true);
-      JobInfo jobA = mock(JobInfo.class);
-      when(jobA.jobId()).thenReturn("jobA");
-      ExecutableStageContext ac1A = factory.get(jobA);
-      factory.release(ac1A);
-      newErr.flush();
-      String output = new String(baos.toByteArray(), Charsets.UTF_8);
-      // Ensure that the error is logged
-      assertTrue(output.contains("Unable to close ExecutableStageContext"));
-    } finally {
-      newErr.flush();
-      System.setErr(oldErr);
-    }
+    Creator creator = mock(Creator.class);
+    ExecutableStageContext c1 = mock(ExecutableStageContext.class);
+    when(creator.apply(any(JobInfo.class))).thenReturn(c1);
+    // Make close() throw an Error, not an Exception, to check that it is caught and logged.
+    doThrow(new NoClassDefFoundError()).when(c1).close();
+    ReferenceCountingExecutableStageContextFactory factory =
+        ReferenceCountingExecutableStageContextFactory.create(creator, (x) -> true);
+    JobInfo jobA = mock(JobInfo.class);
+    when(jobA.jobId()).thenReturn("jobA");
+    ExecutableStageContext ac1A = factory.get(jobA);
+    factory.release(ac1A);
+    expectedLogs.verifyError("Unable to close ExecutableStageContext");
   }
 }
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
index 6942f12ee35..2ccbc0174e8 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
@@ -23,10 +23,14 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Prec
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.lessThanOrEqualTo;
 import static org.hamcrest.Matchers.not;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
@@ -59,6 +63,7 @@ import java.util.function.Function;
 import org.apache.beam.fn.harness.Caches;
 import org.apache.beam.fn.harness.FnHarness;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.LogEntry.Severity;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitResponse;
@@ -83,7 +88,7 @@ import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.Exec
 import org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor;
 import org.apache.beam.runners.fnexecution.data.GrpcDataService;
 import org.apache.beam.runners.fnexecution.logging.GrpcLoggingService;
-import org.apache.beam.runners.fnexecution.logging.Slf4jLogWriter;
+import org.apache.beam.runners.fnexecution.logging.LogWriter;
 import org.apache.beam.runners.fnexecution.state.GrpcStateService;
 import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
 import org.apache.beam.runners.fnexecution.state.StateRequestHandlers;
@@ -154,6 +159,7 @@ import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
+import org.slf4j.LoggerFactory;
 
 /**
  * Tests the execution of a pipeline from specification time to executing a single fused stage,
@@ -176,6 +182,7 @@ public class RemoteExecutionTest implements Serializable {
   private transient GrpcFnServer<FnApiControlClientPoolService> controlServer;
   private transient GrpcFnServer<GrpcDataService> dataServer;
   private transient GrpcFnServer<GrpcStateService> stateServer;
+  private transient LogCapturer logCapturer;
   private transient GrpcFnServer<GrpcLoggingService> loggingServer;
   private transient GrpcStateService stateDelegator;
   private transient SdkHarnessClient controlClient;
@@ -184,6 +191,15 @@ public class RemoteExecutionTest implements Serializable {
   private transient ExecutorService sdkHarnessExecutor;
   private transient Future<?> sdkHarnessExecutorFuture;
 
+  private static class LogCapturer implements LogWriter {
+    final List<BeamFnApi.LogEntry> capturedLogs = Collections.synchronizedList(new ArrayList<>());
+
+    @Override
+    public void log(BeamFnApi.LogEntry entry) {
+      capturedLogs.add(entry);
+    }
+  }
+
   public void launchSdkHarness(PipelineOptions options) throws Exception {
     // Setup execution-time servers
     ThreadFactory threadFactory = new ThreadFactoryBuilder().setDaemon(true).build();
@@ -196,9 +212,10 @@ public class RemoteExecutionTest implements Serializable {
                 serverExecutor,
                 OutboundObserverFactory.serverDirect()),
             serverFactory);
+    logCapturer = new LogCapturer();
     loggingServer =
         GrpcFnServer.allocatePortAndCreateFor(
-            GrpcLoggingService.forWriter(Slf4jLogWriter.getDefault()), serverFactory);
+            GrpcLoggingService.forWriter(logCapturer), serverFactory);
     stateDelegator = GrpcStateService.create();
     stateServer = GrpcFnServer.allocatePortAndCreateFor(stateDelegator, serverFactory);
 
@@ -253,6 +270,7 @@ public class RemoteExecutionTest implements Serializable {
         throw e;
       }
     }
+    logCapturer = null;
   }
 
   @Test
@@ -313,8 +331,8 @@ public class RemoteExecutionTest implements Serializable {
               (Coder) remoteOutputCoder.getValue(),
               (FnDataReceiver<? super WindowedValue<?>>) outputContents::add));
     }
-    // The impulse example
 
+    // The impulse example
     try (RemoteBundle bundle =
         processor.newBundle(outputReceivers, BundleProgressHandler.ignored())) {
       Iterables.getOnlyElement(bundle.getInputReceivers().values())
@@ -331,6 +349,115 @@ public class RemoteExecutionTest implements Serializable {
     }
   }
 
+  @Test
+  public void testLogging() throws Exception {
+    long startTime = System.currentTimeMillis();
+    launchSdkHarness(PipelineOptionsFactory.create());
+    Pipeline p = Pipeline.create();
+    p.apply("impulse", Impulse.create())
+        .apply(
+            "create",
+            ParDo.of(
+                new DoFn<byte[], String>() {
+                  @ProcessElement
+                  public void process(ProcessContext ctxt) {
+                    ctxt.output("zero");
+                  }
+                }))
+        .apply(
+            "len",
+            ParDo.of(
+                new DoFn<String, Long>() {
+                  @ProcessElement
+                  public void process(ProcessContext ctxt) {
+                    org.slf4j.Logger logger = LoggerFactory.getLogger(RemoteExecutionTest.class);
+                    logger.warn("TEST" + ctxt.element());
+                    logger.error("TEST_EXCEPTION" + ctxt.element(), new Exception());
+                  }
+                }))
+        .apply("addKeys", WithKeys.of("foo"))
+        // Use some unknown coders
+        .setCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianLongCoder.of()))
+        // Force the output to be materialized
+        .apply("gbk", GroupByKey.create());
+
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+    FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
+    checkState(fused.getFusedStages().size() == 1, "Expected exactly one fused stage");
+    ExecutableStage stage = fused.getFusedStages().iterator().next();
+
+    ExecutableProcessBundleDescriptor descriptor =
+        ProcessBundleDescriptors.fromExecutableStage(
+            "my_stage", stage, dataServer.getApiServiceDescriptor());
+    String ptransformId = null;
+    for (Map.Entry<String, RunnerApi.PTransform> entry :
+        descriptor.getProcessBundleDescriptor().getTransformsMap().entrySet()) {
+      if (entry.getValue().getUniqueName().contains("len")) {
+        ptransformId = entry.getKey();
+      }
+    }
+    assertNotNull(ptransformId);
+    BundleProcessor processor =
+        controlClient.getProcessor(
+            descriptor.getProcessBundleDescriptor(), descriptor.getRemoteInputDestinations());
+    Map<String, ? super Coder<WindowedValue<?>>> remoteOutputCoders =
+        descriptor.getRemoteOutputCoders();
+    Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
+    for (Entry<String, ? super Coder<WindowedValue<?>>> remoteOutputCoder :
+        remoteOutputCoders.entrySet()) {
+      List<? super WindowedValue<?>> outputContents =
+          Collections.synchronizedList(new ArrayList<>());
+      outputReceivers.put(
+          remoteOutputCoder.getKey(),
+          RemoteOutputReceiver.of(
+              (Coder) remoteOutputCoder.getValue(),
+              (FnDataReceiver<? super WindowedValue<?>>) outputContents::add));
+    }
+
+    String instructionId;
+    // Execute a bundle that logs.
+    try (RemoteBundle bundle =
+        processor.newBundle(outputReceivers, BundleProgressHandler.ignored())) {
+      instructionId = bundle.getId();
+      Iterables.getOnlyElement(bundle.getInputReceivers().values())
+          .accept(valueInGlobalWindow(new byte[0]));
+    }
+
+    boolean foundPTransformLog = false;
+    boolean foundExceptionLog = false;
+    while (!foundPTransformLog || !foundExceptionLog) {
+      BeamFnApi.LogEntry[] logs = logCapturer.capturedLogs.toArray(new BeamFnApi.LogEntry[0]);
+      for (BeamFnApi.LogEntry log : logs) {
+        assertThat(
+            log.getTimestamp().getSeconds() * 1000 + log.getTimestamp().getNanos() / 1_000_000,
+            allOf(greaterThanOrEqualTo(startTime), lessThanOrEqualTo(System.currentTimeMillis())));
+        assertThat(log.getThread(), not(""));
+        assertThat(log.getLogLocation(), not(""));
+
+        if ("TESTzero".equals(log.getMessage())) {
+          assertThat(log.getSeverity(), equalTo(Severity.Enum.WARN));
+          assertThat(log.getInstructionId(), equalTo(instructionId));
+          assertThat(log.getLogLocation(), equalTo(RemoteExecutionTest.class.getCanonicalName()));
+          assertThat(log.getTransformId(), equalTo(ptransformId));
+          assertThat(log.getTrace(), equalTo(""));
+          foundPTransformLog = true;
+        } else if ("TEST_EXCEPTIONzero".equals(log.getMessage())) {
+          assertThat(log.getSeverity(), equalTo(Severity.Enum.ERROR));
+          assertThat(log.getInstructionId(), equalTo(instructionId));
+          assertThat(log.getLogLocation(), equalTo(RemoteExecutionTest.class.getCanonicalName()));
+          assertThat(log.getTransformId(), equalTo(ptransformId));
+          assertThat(log.getTrace(), containsString("RemoteExecutionTest"));
+          foundExceptionLog = true;
+        }
+      }
+      if (!foundPTransformLog || !foundExceptionLog) {
+        // Fail instead of silently passing if the logs never arrive, else wait for more.
+        assertTrue("Timed out waiting for logs", System.currentTimeMillis() - startTime < 30_000L);
+        Thread.sleep(500);
+      }
+    }
+  }
+
   @Test
   public void testBundleProcessorThrowsExecutionExceptionWhenUserCodeThrows() throws Exception {
     launchSdkHarness(PipelineOptionsFactory.create());
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
index 9183d16785a..103fa780328 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
@@ -280,6 +280,7 @@ public class FnHarness {
               metricsShortIds,
               executionStateSampler,
               processWideCache);
+      logging.setProcessBundleHandler(processBundleHandler);
 
       BeamFnStatusClient beamFnStatusClient = null;
       if (statusApiServiceDescriptor != null) {
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java
index cd90a9c0a08..51f3ead2df5 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java
@@ -178,9 +178,10 @@ public class ExecutionStateSampler {
     // Read by the ExecutionStateSampler, written by the bundle processing thread lazily and
     // frequently.
     private final AtomicLong numTransitionsLazy;
-
-    // Read by multiple threads, read and written by the bundle processing thread lazily.
-    private final AtomicReference<@Nullable ExecutionStateImpl> currentState;
+    // Read and written by the bundle processing thread frequently.
+    private @Nullable ExecutionStateImpl currentState;
+    // Read by multiple threads, written by the bundle processing thread lazily.
+    private final AtomicReference<@Nullable ExecutionStateImpl> currentStateLazy;
     // Read and written by the ExecutionStateSampler thread
     private long transitionsAtLastSample;
 
@@ -189,7 +190,7 @@ public class ExecutionStateSampler {
       this.trackedThread = new AtomicReference<>();
       this.lastTransitionTime = new AtomicLong();
       this.numTransitionsLazy = new AtomicLong();
-      this.currentState = new AtomicReference<>();
+      this.currentStateLazy = new AtomicReference<>();
       this.processBundleId = new AtomicReference<>();
     }
 
@@ -213,7 +214,7 @@ public class ExecutionStateSampler {
      *     approximation, all of that time should be associated with this state.
      */
     private void takeSample(long currentTimeMillis, long millisSinceLastSample) {
-      ExecutionStateImpl currentExecutionState = currentState.get();
+      ExecutionStateImpl currentExecutionState = currentStateLazy.get();
       if (currentExecutionState != null) {
         currentExecutionState.takeSample(millisSinceLastSample);
       }
@@ -263,7 +264,7 @@ public class ExecutionStateSampler {
       }
       long lastTransitionTimeMs = lastTransitionTime.get();
       // We are actively processing a bundle but may have not yet entered into a state.
-      ExecutionStateImpl current = currentState.get();
+      ExecutionStateImpl current = currentStateLazy.get();
       if (current != null) {
         return ExecutionStateTrackerStatus.create(
             current.ptransformId, current.ptransformUniqueName, thread, lastTransitionTimeMs);
@@ -272,6 +273,14 @@ public class ExecutionStateSampler {
       }
     }
 
+    /** Returns the active PTransform id or null; only reliable on the bundle processing thread. */
+    public @Nullable String getCurrentThreadsPTransformId() {
+      // currentState is a plain field, so read it once: a concurrent transition must not be able
+      // to null it out between the null check and the dereference.
+      ExecutionStateImpl state = currentState;
+      return state == null ? null : state.ptransformId;
+    }
+
     /** {@link ExecutionState} represents the current state of an execution thread. */
     private class ExecutionStateImpl implements ExecutionState {
       private final String shortId;
@@ -331,15 +340,17 @@ public class ExecutionStateSampler {
 
       @Override
       public void activate() {
-        previousState = currentState.get();
-        currentState.lazySet(this);
+        previousState = currentState;
+        currentState = this; // Plain field for reads on the bundle processing thread.
+        currentStateLazy.lazySet(this); // Lazily published for other threads (e.g. takeSample).
         numTransitions += 1;
         numTransitionsLazy.lazySet(numTransitions);
       }
 
       @Override
       public void deactivate() {
-        currentState.lazySet(previousState);
+        currentState = previousState;
+        currentStateLazy.lazySet(previousState);
         previousState = null;
 
         numTransitions += 1;
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClient.java
index a9deb195e25..cd99cc424d9 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClient.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClient.java
@@ -41,14 +41,12 @@ import java.util.logging.LogManager;
 import java.util.logging.LogRecord;
 import java.util.logging.Logger;
 import java.util.logging.SimpleFormatter;
+import org.apache.beam.fn.harness.control.ProcessBundleHandler;
+import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessor;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.LogEntry;
 import org.apache.beam.model.fnexecution.v1.BeamFnLoggingGrpc;
 import org.apache.beam.model.pipeline.v1.Endpoints;
-import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
-import org.apache.beam.runners.core.metrics.ExecutionStateTracker.ExecutionState;
-import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
-import org.apache.beam.runners.core.metrics.SimpleExecutionState;
 import org.apache.beam.sdk.extensions.gcp.options.GcsOptions;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.SdkHarnessOptions;
@@ -107,6 +105,7 @@ public class BeamFnLoggingClient implements AutoCloseable {
   private final LogRecordHandler logRecordHandler;
   private final CompletableFuture<Object> inboundObserverCompletion;
   private final Phaser phaser;
+  private volatile @Nullable ProcessBundleHandler processBundleHandler;
 
   public BeamFnLoggingClient(
       PipelineOptions options,
@@ -151,6 +150,10 @@ public class BeamFnLoggingClient implements AutoCloseable {
     rootLogger.addHandler(logRecordHandler);
   }
 
+  public void setProcessBundleHandler(ProcessBundleHandler handler) {
+    processBundleHandler = handler;
+  }
+
   @Override
   public void close() throws Exception {
     try {
@@ -227,14 +230,14 @@ public class BeamFnLoggingClient implements AutoCloseable {
       if (loggerName != null) {
         builder.setLogLocation(loggerName);
       }
-      ExecutionState state = ExecutionStateTracker.getCurrentExecutionState(record.getThreadID());
-      if (state instanceof SimpleExecutionState) {
-        String transformId =
-            ((SimpleExecutionState) state)
-                .getLabels()
-                .get(MonitoringInfoConstants.Labels.PTRANSFORM);
-        if (transformId != null) {
-          builder.setTransformId(transformId);
+      if (instructionId != null && processBundleHandler != null) {
+        BundleProcessor bundleProcessor =
+            processBundleHandler.getBundleProcessorCache().find(instructionId);
+        if (bundleProcessor != null) {
+          String transformId = bundleProcessor.getStateTracker().getCurrentThreadsPTransformId();
+          if (transformId != null) {
+            builder.setTransformId(transformId);
+          }
         }
       }
 
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java
index d6ff927bd68..72a106a75a7 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java
@@ -110,6 +110,10 @@ public class ExecutionStateSamplerTest {
               }
             });
 
+    // No active PTransform
+    assertNull(tracker1.getCurrentThreadsPTransformId());
+    assertNull(tracker2.getCurrentThreadsPTransformId());
+
     // No tracked thread
     assertNull(tracker1.getStatus());
     assertNull(tracker2.getStatus());
@@ -120,6 +124,10 @@ public class ExecutionStateSamplerTest {
     state1.activate();
     state2.activate();
 
+    // Check that the current threads PTransform id is available
+    assertEquals("ptransformId1", tracker1.getCurrentThreadsPTransformId());
+    assertEquals("ptransformId2", tracker2.getCurrentThreadsPTransformId());
+
     // Check that the status returns a value as soon as it is activated.
     ExecutionStateTrackerStatus activeBundleStatus1 = tracker1.getStatus();
     ExecutionStateTrackerStatus activeBundleStatus2 = tracker2.getStatus();
@@ -142,6 +150,10 @@ public class ExecutionStateSamplerTest {
     waitTillActive.countDown();
     waitForSamples.await();
 
+    // Check that the current threads PTransform id is available
+    assertEquals("ptransformId1", tracker1.getCurrentThreadsPTransformId());
+    assertEquals("ptransformId2", tracker2.getCurrentThreadsPTransformId());
+
     // Check that we get additional data about the active PTransform.
     ExecutionStateTrackerStatus activeStateStatus1 = tracker1.getStatus();
     ExecutionStateTrackerStatus activeStateStatus2 = tracker2.getStatus();
@@ -184,6 +196,10 @@ public class ExecutionStateSamplerTest {
     waitTillStatesDeactivated.countDown();
     waitForEvenMoreSamples.await();
 
+    // Check that the current threads PTransform id is not available
+    assertNull(tracker1.getCurrentThreadsPTransformId());
+    assertNull(tracker2.getCurrentThreadsPTransformId());
+
     // Check the status once the states are deactivated but the bundle is still active
     ExecutionStateTrackerStatus inactiveStateStatus1 = tracker1.getStatus();
     ExecutionStateTrackerStatus inactiveStateStatus2 = tracker2.getStatus();
@@ -221,7 +237,9 @@ public class ExecutionStateSamplerTest {
     tracker1.reset();
     tracker2.reset();
 
-    // Shouldn't have a status returned since there is no active bundle.
+    // Shouldn't have a status or PTransform id returned since there is no active bundle.
+    assertNull(tracker1.getCurrentThreadsPTransformId());
+    assertNull(tracker2.getCurrentThreadsPTransformId());
     assertNull(tracker1.getStatus());
     assertNull(tracker2.getStatus());