You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text copy).
Posted to commits@beam.apache.org by lc...@apache.org on 2022/08/16 17:56:25 UTC

[beam] branch master updated: [BEAM-13015, #21250] Remove looking up thread local metrics container holder and object creation on hot path (#22627)

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 9b934037432 [BEAM-13015, #21250] Remove looking up thread local metrics container holder and object creation on hot path (#22627)
9b934037432 is described below

commit 9b934037432b178a6af18c31d7c80784d8b44b95
Author: Luke Cwik <lc...@google.com>
AuthorDate: Tue Aug 16 10:56:18 2022 -0700

    [BEAM-13015, #21250] Remove looking up thread local metrics container holder and object creation on hot path (#22627)
    
    * [BEAM-13015, #21250] Remove looking up thread local metrics container holder and object creation on hot path
    
    * Fix up nullness warning
    
    * Fix IWYU (include-what-you-use): declare the powermock test dependency directly on the Dataflow worker
---
 .../google-cloud-dataflow-java/worker/build.gradle |  1 +
 .../worker/legacy-worker/build.gradle              |  1 +
 .../beam/sdk/metrics/MetricsEnvironment.java       | 33 +++++++---
 sdks/java/harness/build.gradle                     |  2 -
 .../fn/harness/control/ProcessBundleHandler.java   | 44 ++++++++++++-
 .../harness/data/PCollectionConsumerRegistry.java  | 72 +++++++++++++---------
 .../harness/data/PTransformFunctionRegistry.java   | 29 +++++----
 .../harness/control/ProcessBundleHandlerTest.java  |  7 +++
 .../data/PCollectionConsumerRegistryTest.java      | 52 +++++++++++-----
 .../data/PTransformFunctionRegistryTest.java       | 44 ++++++++-----
 10 files changed, 204 insertions(+), 81 deletions(-)

diff --git a/runners/google-cloud-dataflow-java/worker/build.gradle b/runners/google-cloud-dataflow-java/worker/build.gradle
index e1c12a0caf5..25e8e1eca20 100644
--- a/runners/google-cloud-dataflow-java/worker/build.gradle
+++ b/runners/google-cloud-dataflow-java/worker/build.gradle
@@ -122,6 +122,7 @@ dependencies {
   shadowTest library.java.jsonassert
   shadowTest library.java.junit
   shadowTest library.java.mockito_core
+  shadowTest library.java.powermock
 }
 
 //TODO(https://github.com/apache/beam/issues/19115): checktyle task should be enabled in the future.
diff --git a/runners/google-cloud-dataflow-java/worker/legacy-worker/build.gradle b/runners/google-cloud-dataflow-java/worker/legacy-worker/build.gradle
index b86d193dc9d..11852949e7a 100644
--- a/runners/google-cloud-dataflow-java/worker/legacy-worker/build.gradle
+++ b/runners/google-cloud-dataflow-java/worker/legacy-worker/build.gradle
@@ -244,6 +244,7 @@ dependencies {
     shadowTest library.java.jsonassert
     shadowTest library.java.junit
     shadowTest library.java.mockito_core
+    shadowTest library.java.powermock
 }
 
 project.task('validateShadedJarContainsSlf4jJdk14', dependsOn: 'shadowJar') {
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricsEnvironment.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricsEnvironment.java
index d01ce4ecc93..bf6da9bdb32 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricsEnvironment.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricsEnvironment.java
@@ -53,12 +53,18 @@ public class MetricsEnvironment {
   private static final AtomicBoolean METRICS_SUPPORTED = new AtomicBoolean(false);
   private static final AtomicBoolean REPORTED_MISSING_CONTAINER = new AtomicBoolean(false);
 
-  private static final ThreadLocal<@Nullable MetricsContainerHolder> CONTAINER_FOR_THREAD =
+  @SuppressWarnings("type.argument.type.incompatible") // object guaranteed to be non-null
+  private static final ThreadLocal<@NonNull MetricsContainerHolder> CONTAINER_FOR_THREAD =
       ThreadLocal.withInitial(MetricsContainerHolder::new);
 
   private static final AtomicReference<@Nullable MetricsContainer> PROCESS_WIDE_METRICS_CONTAINER =
       new AtomicReference<>();
 
+  /** Returns the container holder for the current thread. */
+  public static MetricsEnvironmentState getMetricsEnvironmentStateForCurrentThread() {
+    return CONTAINER_FOR_THREAD.get();
+  }
+
   /**
    * Set the {@link MetricsContainer} for the current thread.
    *
@@ -66,8 +72,6 @@ public class MetricsEnvironment {
    */
   public static @Nullable MetricsContainer setCurrentContainer(
       @Nullable MetricsContainer container) {
-    @SuppressWarnings("nullness") // Non-null due to withInitialValue
-    @NonNull
     MetricsContainerHolder holder = CONTAINER_FOR_THREAD.get();
     @Nullable MetricsContainer previous = holder.container;
     holder.container = container;
@@ -108,7 +112,6 @@ public class MetricsEnvironment {
     private final MetricsContainerHolder holder;
     private final @Nullable MetricsContainer oldContainer;
 
-    @SuppressWarnings("nullness") // Non-null due to withInitialValue
     private ScopedContainer(MetricsContainer newContainer) {
       // It is safe to cache the thread-local holder because it never changes for the thread.
       holder = CONTAINER_FOR_THREAD.get();
@@ -130,7 +133,6 @@ public class MetricsEnvironment {
    * diagnostic message.
    */
   public static @Nullable MetricsContainer getCurrentContainer() {
-    @SuppressWarnings("nullness") // Non-null due to withInitialValue
     MetricsContainer container = CONTAINER_FOR_THREAD.get().container;
     if (container == null && REPORTED_MISSING_CONTAINER.compareAndSet(false, true)) {
       if (isMetricsSupported()) {
@@ -149,7 +151,24 @@ public class MetricsEnvironment {
     return PROCESS_WIDE_METRICS_CONTAINER.get();
   }
 
-  private static class MetricsContainerHolder {
-    public @Nullable MetricsContainer container = null;
+  public static class MetricsContainerHolder implements MetricsEnvironmentState {
+    private @Nullable MetricsContainer container = null;
+
+    @Override
+    public @Nullable MetricsContainer activate(@Nullable MetricsContainer metricsContainer) {
+      MetricsContainer old = container;
+      container = metricsContainer;
+      return old;
+    }
+  }
+
+  /**
+   * Set the {@link MetricsContainer} for the associated {@link MetricsEnvironment}.
+   *
+   * @return The previous container for the associated {@link MetricsEnvironment}.
+   */
+  public interface MetricsEnvironmentState {
+    @Nullable
+    MetricsContainer activate(@Nullable MetricsContainer metricsContainer);
   }
 }
diff --git a/sdks/java/harness/build.gradle b/sdks/java/harness/build.gradle
index 37cb1d352fe..2b69847272d 100644
--- a/sdks/java/harness/build.gradle
+++ b/sdks/java/harness/build.gradle
@@ -63,8 +63,6 @@ dependencies {
     implementation project(it)
   }
   shadow library.java.vendored_guava_26_0_jre
-  shadowTest library.java.powermock
-  shadowTest library.java.powermock_mockito
   implementation library.java.joda_time
   implementation library.java.slf4j_api
   implementation library.java.vendored_grpc_1_48_1
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index e6fce8e24c3..bddfcc8c360 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -83,6 +83,9 @@ import org.apache.beam.sdk.fn.data.DataEndpoint;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.fn.data.TimerEndpoint;
 import org.apache.beam.sdk.function.ThrowingRunnable;
+import org.apache.beam.sdk.metrics.MetricsContainer;
+import org.apache.beam.sdk.metrics.MetricsEnvironment;
+import org.apache.beam.sdk.metrics.MetricsEnvironment.MetricsEnvironmentState;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer;
 import org.apache.beam.sdk.util.WindowedValue;
@@ -99,6 +102,7 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Immutabl
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.SetMultimap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
+import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Instant;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -513,6 +517,7 @@ public class ProcessBundleHandler {
       ExecutionStateTracker stateTracker = bundleProcessor.getStateTracker();
 
       try (HandleStateCallsForBundle beamFnStateClient = bundleProcessor.getBeamFnStateClient()) {
+        bundleProcessor.getMetricsEnvironmentStateForBundle().start();
         stateTracker.start(request.getInstructionId());
         try {
           // Already in reverse topological order so we don't need to do anything.
@@ -730,6 +735,24 @@ public class ProcessBundleHandler {
     bundleProcessorCache.shutdown();
   }
 
+  @VisibleForTesting
+  static class MetricsEnvironmentStateForBundle implements MetricsEnvironmentState {
+    private @Nullable MetricsEnvironmentState currentThreadState;
+
+    @Override
+    public @Nullable MetricsContainer activate(@Nullable MetricsContainer metricsContainer) {
+      return currentThreadState.activate(metricsContainer);
+    }
+
+    public void start() {
+      currentThreadState = MetricsEnvironment.getMetricsEnvironmentStateForCurrentThread();
+    }
+
+    public void reset() {
+      currentThreadState = null;
+    }
+  }
+
   private BundleProcessor createBundleProcessor(
       String bundleId, BeamFnApi.ProcessBundleRequest processBundleRequest) throws IOException {
     BeamFnApi.ProcessBundleDescriptor bundleDescriptor = fnApiRegistry.apply(bundleId);
@@ -738,11 +761,14 @@ public class ProcessBundleHandler {
     BundleProgressReporter.InMemory bundleProgressReporterAndRegistrar =
         new BundleProgressReporter.InMemory();
     MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+    MetricsEnvironmentStateForBundle metricsEnvironmentStateForBundle =
+        new MetricsEnvironmentStateForBundle();
     ExecutionStateTracker stateTracker = executionStateSampler.create();
     bundleProgressReporterAndRegistrar.register(stateTracker);
     PCollectionConsumerRegistry pCollectionConsumerRegistry =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
+            metricsEnvironmentStateForBundle,
             stateTracker,
             shortIds,
             bundleProgressReporterAndRegistrar,
@@ -751,10 +777,18 @@ public class ProcessBundleHandler {
 
     PTransformFunctionRegistry startFunctionRegistry =
         new PTransformFunctionRegistry(
-            metricsContainerRegistry, shortIds, stateTracker, Urns.START_BUNDLE_MSECS);
+            metricsContainerRegistry,
+            metricsEnvironmentStateForBundle,
+            shortIds,
+            stateTracker,
+            Urns.START_BUNDLE_MSECS);
     PTransformFunctionRegistry finishFunctionRegistry =
         new PTransformFunctionRegistry(
-            metricsContainerRegistry, shortIds, stateTracker, Urns.FINISH_BUNDLE_MSECS);
+            metricsContainerRegistry,
+            metricsEnvironmentStateForBundle,
+            shortIds,
+            stateTracker,
+            Urns.FINISH_BUNDLE_MSECS);
     List<ThrowingRunnable> resetFunctions = new ArrayList<>();
     List<ThrowingRunnable> tearDownFunctions = new ArrayList<>();
 
@@ -802,6 +836,7 @@ public class ProcessBundleHandler {
             splitListener,
             pCollectionConsumerRegistry,
             metricsContainerRegistry,
+            metricsEnvironmentStateForBundle,
             stateTracker,
             beamFnStateClient,
             bundleFinalizationCallbackRegistrations,
@@ -994,6 +1029,7 @@ public class ProcessBundleHandler {
         BundleSplitListener.InMemory splitListener,
         PCollectionConsumerRegistry pCollectionConsumerRegistry,
         MetricsContainerStepMap metricsContainerRegistry,
+        MetricsEnvironmentStateForBundle metricsEnvironmentStateForBundle,
         ExecutionStateTracker stateTracker,
         HandleStateCallsForBundle beamFnStateClient,
         Collection<CallbackRegistration> bundleFinalizationCallbackRegistrations,
@@ -1009,6 +1045,7 @@ public class ProcessBundleHandler {
           splitListener,
           pCollectionConsumerRegistry,
           metricsContainerRegistry,
+          metricsEnvironmentStateForBundle,
           stateTracker,
           beamFnStateClient,
           /*inboundEndpointApiServiceDescriptors=*/ new ArrayList<>(),
@@ -1046,6 +1083,8 @@ public class ProcessBundleHandler {
 
     abstract MetricsContainerStepMap getMetricsContainerRegistry();
 
+    abstract MetricsEnvironmentStateForBundle getMetricsEnvironmentStateForBundle();
+
     public abstract ExecutionStateTracker getStateTracker();
 
     abstract HandleStateCallsForBundle getBeamFnStateClient();
@@ -1114,6 +1153,7 @@ public class ProcessBundleHandler {
       }
       getSplitListener().clear();
       getMetricsContainerRegistry().reset();
+      getMetricsEnvironmentStateForBundle().reset();
       getStateTracker().reset();
       getBundleFinalizationCallbackRegistrations().clear();
       for (ThrowingRunnable resetFunction : getResetFunctions()) {
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
index 34d0967b85b..556076ed3b1 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
@@ -18,7 +18,6 @@
 package org.apache.beam.fn.harness.data;
 
 import com.google.auto.value.AutoValue;
-import java.io.Closeable;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -48,7 +47,7 @@ import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.metrics.Distribution;
 import org.apache.beam.sdk.metrics.MetricsContainer;
-import org.apache.beam.sdk.metrics.MetricsEnvironment;
+import org.apache.beam.sdk.metrics.MetricsEnvironment.MetricsEnvironmentState;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
 import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
@@ -90,6 +89,7 @@ public class PCollectionConsumerRegistry {
   }
 
   private final MetricsContainerStepMap metricsContainerRegistry;
+  private final MetricsEnvironmentState metricsEnvironmentState;
   private final ExecutionStateTracker stateTracker;
   private final ShortIdMap shortIdMap;
   private final Map<String, List<ConsumerAndMetadata>> pCollectionIdsToConsumers;
@@ -100,11 +100,13 @@ public class PCollectionConsumerRegistry {
 
   public PCollectionConsumerRegistry(
       MetricsContainerStepMap metricsContainerRegistry,
+      MetricsEnvironmentState metricsEnvironmentState,
       ExecutionStateTracker stateTracker,
       ShortIdMap shortIdMap,
       BundleProgressReporter.Registrar bundleProgressReporterRegistrar,
       ProcessBundleDescriptor processBundleDescriptor) {
     this.metricsContainerRegistry = metricsContainerRegistry;
+    this.metricsEnvironmentState = metricsEnvironmentState;
     this.stateTracker = stateTracker;
     this.shortIdMap = shortIdMap;
     this.pCollectionIdsToConsumers = new HashMap<>();
@@ -217,16 +219,18 @@ public class PCollectionConsumerRegistry {
           if (consumerAndMetadatas.size() == 1) {
             ConsumerAndMetadata consumerAndMetadata = consumerAndMetadatas.get(0);
             if (consumerAndMetadata.getConsumer() instanceof HandlesSplits) {
-              return new SplittingMetricTrackingFnDataReceiver(pcId, coder, consumerAndMetadata);
+              return new SplittingMetricTrackingFnDataReceiver(
+                  pcId, coder, consumerAndMetadata, metricsEnvironmentState);
             }
-            return new MetricTrackingFnDataReceiver(pcId, coder, consumerAndMetadata);
+            return new MetricTrackingFnDataReceiver(
+                pcId, coder, consumerAndMetadata, metricsEnvironmentState);
           } else {
             /* TODO(SDF), Consider supporting splitting each consumer individually. This would never
             come up in the existing SDF expansion, but might be useful to support fused SDF nodes.
             This would require dedicated delivery of the split results to each of the consumers
             separately. */
             return new MultiplexingMetricTrackingFnDataReceiver(
-                pcId, coder, ImmutableList.copyOf(consumerAndMetadatas));
+                pcId, coder, ImmutableList.copyOf(consumerAndMetadatas), metricsEnvironmentState);
           }
         });
   }
@@ -240,16 +244,20 @@ public class PCollectionConsumerRegistry {
    */
   private class MetricTrackingFnDataReceiver<T> implements FnDataReceiver<WindowedValue<T>> {
     private final FnDataReceiver<WindowedValue<T>> delegate;
-    private final ExecutionState state;
+    private final ExecutionState executionState;
     private final BundleCounter elementCountCounter;
     private final SampleByteSizeDistribution<T> sampledByteSizeDistribution;
     private final Coder<T> coder;
     private final MetricsContainer metricsContainer;
+    private final MetricsEnvironmentState metricsEnvironmentState;
 
     public MetricTrackingFnDataReceiver(
-        String pCollectionId, Coder<T> coder, ConsumerAndMetadata consumerAndMetadata) {
+        String pCollectionId,
+        Coder<T> coder,
+        ConsumerAndMetadata consumerAndMetadata,
+        MetricsEnvironmentState metricsEnvironmentState) {
       this.delegate = consumerAndMetadata.getConsumer();
-      this.state = consumerAndMetadata.getExecutionState();
+      this.executionState = consumerAndMetadata.getExecutionState();
 
       HashMap<String, String> labels = new HashMap<>();
       labels.put(Labels.PCOLLECTION, pCollectionId);
@@ -284,6 +292,7 @@ public class PCollectionConsumerRegistry {
 
       this.coder = coder;
       this.metricsContainer = consumerAndMetadata.getMetricsContainer();
+      this.metricsEnvironmentState = metricsEnvironmentState;
     }
 
     @Override
@@ -298,13 +307,13 @@ public class PCollectionConsumerRegistry {
       // PTransform context. This ensures that user metrics obtain the pTransform ID when they are
       // created. Also use the ExecutionStateTracker and enter an appropriate state to track the
       // Process Bundle Execution time metric.
-      try (Closeable closeable = MetricsEnvironment.scopedMetricsContainer(metricsContainer)) {
-        state.activate();
-        try {
-          this.delegate.accept(input);
-        } finally {
-          state.deactivate();
-        }
+      MetricsContainer oldContainer = metricsEnvironmentState.activate(metricsContainer);
+      executionState.activate();
+      try {
+        this.delegate.accept(input);
+      } finally {
+        executionState.deactivate();
+        metricsEnvironmentState.activate(oldContainer);
       }
       this.sampledByteSizeDistribution.finishLazyUpdate();
     }
@@ -320,13 +329,18 @@ public class PCollectionConsumerRegistry {
   private class MultiplexingMetricTrackingFnDataReceiver<T>
       implements FnDataReceiver<WindowedValue<T>> {
     private final List<ConsumerAndMetadata> consumerAndMetadatas;
+    private final MetricsEnvironmentState metricsEnvironmentState;
     private final BundleCounter elementCountCounter;
     private final SampleByteSizeDistribution<T> sampledByteSizeDistribution;
     private final Coder<T> coder;
 
     public MultiplexingMetricTrackingFnDataReceiver(
-        String pCollectionId, Coder<T> coder, List<ConsumerAndMetadata> consumerAndMetadatas) {
+        String pCollectionId,
+        Coder<T> coder,
+        List<ConsumerAndMetadata> consumerAndMetadatas,
+        MetricsEnvironmentState metricsEnvironmentState) {
       this.consumerAndMetadatas = consumerAndMetadatas;
+      this.metricsEnvironmentState = metricsEnvironmentState;
 
       HashMap<String, String> labels = new HashMap<>();
       labels.put(Labels.PCOLLECTION, pCollectionId);
@@ -375,16 +389,15 @@ public class PCollectionConsumerRegistry {
       // created. Also use the ExecutionStateTracker and enter an appropriate state to track the
       // Process Bundle Execution time metric.
       for (ConsumerAndMetadata consumerAndMetadata : consumerAndMetadatas) {
-
-        try (Closeable closeable =
-            MetricsEnvironment.scopedMetricsContainer(consumerAndMetadata.getMetricsContainer())) {
-          ExecutionState state = consumerAndMetadata.getExecutionState();
-          state.activate();
-          try {
-            consumerAndMetadata.getConsumer().accept(input);
-          } finally {
-            state.deactivate();
-          }
+        MetricsContainer oldContainer =
+            metricsEnvironmentState.activate(consumerAndMetadata.getMetricsContainer());
+        ExecutionState state = consumerAndMetadata.getExecutionState();
+        state.activate();
+        try {
+          consumerAndMetadata.getConsumer().accept(input);
+        } finally {
+          state.deactivate();
+          metricsEnvironmentState.activate(oldContainer);
         }
         this.sampledByteSizeDistribution.finishLazyUpdate();
       }
@@ -403,8 +416,11 @@ public class PCollectionConsumerRegistry {
     private final HandlesSplits delegate;
 
     public SplittingMetricTrackingFnDataReceiver(
-        String pCollection, Coder<T> coder, ConsumerAndMetadata consumerAndMetadata) {
-      super(pCollection, coder, consumerAndMetadata);
+        String pCollection,
+        Coder<T> coder,
+        ConsumerAndMetadata consumerAndMetadata,
+        MetricsEnvironmentState metricsEnvironmentState) {
+      super(pCollection, coder, consumerAndMetadata, metricsEnvironmentState);
       this.delegate = (HandlesSplits) consumerAndMetadata.getConsumer();
     }
 
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java
index 2b13e02c610..f6bd008c424 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java
@@ -17,20 +17,19 @@
  */
 package org.apache.beam.fn.harness.data;
 
-import java.io.Closeable;
 import java.util.ArrayList;
 import java.util.List;
 import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionState;
 import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker;
 import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
-import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
 import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
 import org.apache.beam.runners.core.metrics.ShortIdMap;
 import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
 import org.apache.beam.sdk.function.ThrowingRunnable;
-import org.apache.beam.sdk.metrics.MetricsEnvironment;
+import org.apache.beam.sdk.metrics.MetricsContainer;
+import org.apache.beam.sdk.metrics.MetricsEnvironment.MetricsEnvironmentState;
 
 /**
  * A class to to register and retrieve functions for bundle processing (i.e. the start, or finish
@@ -58,7 +57,8 @@ import org.apache.beam.sdk.metrics.MetricsEnvironment;
  */
 public class PTransformFunctionRegistry {
 
-  private MetricsContainerStepMap metricsContainerRegistry;
+  private final MetricsContainerStepMap metricsContainerRegistry;
+  private final MetricsEnvironmentState metricsEnvironmentState;
   private final ExecutionStateTracker stateTracker;
   private final String executionStateUrn;
   private final ShortIdMap shortIds;
@@ -70,11 +70,15 @@ public class PTransformFunctionRegistry {
    *
    * @param metricsContainerRegistry - Used to enable a metric container to properly account for the
    *     pTransform in user metrics.
+   * @param metricsEnvironmentState - Used to activate which metrics container receives counter
+   *     updates.
+   * @param shortIds - Provides short ids for {@link MonitoringInfo}.
    * @param stateTracker - The tracker to enter states in order to calculate execution time metrics.
    * @param executionStateUrn - The URN for the execution state .
    */
   public PTransformFunctionRegistry(
       MetricsContainerStepMap metricsContainerRegistry,
+      MetricsEnvironmentState metricsEnvironmentState,
       ShortIdMap shortIds,
       ExecutionStateTracker stateTracker,
       String executionStateUrn) {
@@ -89,6 +93,7 @@ public class PTransformFunctionRegistry {
         throw new IllegalArgumentException(String.format("Unknown URN %s", executionStateUrn));
     }
     this.metricsContainerRegistry = metricsContainerRegistry;
+    this.metricsEnvironmentState = metricsEnvironmentState;
     this.shortIds = shortIds;
     this.executionStateUrn = executionStateUrn;
     this.stateTracker = stateTracker;
@@ -118,17 +123,17 @@ public class PTransformFunctionRegistry {
     ExecutionState executionState =
         stateTracker.create(shortId, pTransformId, pTransformUniqueName, stateName);
 
-    MetricsContainerImpl container = metricsContainerRegistry.getContainer(pTransformId);
+    MetricsContainer container = metricsContainerRegistry.getContainer(pTransformId);
 
     ThrowingRunnable wrapped =
         () -> {
-          try (Closeable metricCloseable = MetricsEnvironment.scopedMetricsContainer(container)) {
-            executionState.activate();
-            try {
-              runnable.run();
-            } finally {
-              executionState.deactivate();
-            }
+          MetricsContainer oldContainer = metricsEnvironmentState.activate(container);
+          executionState.activate();
+          try {
+            runnable.run();
+          } finally {
+            executionState.deactivate();
+            metricsEnvironmentState.activate(oldContainer);
           }
         };
     runnables.add(wrapped);
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
index 343f80cf6e9..404e63b3edf 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
@@ -74,6 +74,7 @@ import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTr
 import org.apache.beam.fn.harness.control.FinalizeBundleHandler.CallbackRegistration;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessor;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessorCache;
+import org.apache.beam.fn.harness.control.ProcessBundleHandler.MetricsEnvironmentStateForBundle;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.data.PCollectionConsumerRegistry;
 import org.apache.beam.fn.harness.data.PTransformFunctionRegistry;
@@ -294,6 +295,11 @@ public class ProcessBundleHandlerTest {
       return wrappedBundleProcessor.getMetricsContainerRegistry();
     }
 
+    @Override
+    MetricsEnvironmentStateForBundle getMetricsEnvironmentStateForBundle() {
+      return wrappedBundleProcessor.getMetricsEnvironmentStateForBundle();
+    }
+
     @Override
     public ExecutionStateTracker getStateTracker() {
       return wrappedBundleProcessor.getStateTracker();
@@ -742,6 +748,7 @@ public class ProcessBundleHandlerTest {
             splitListener,
             pCollectionConsumerRegistry,
             metricsContainerRegistry,
+            new MetricsEnvironmentStateForBundle(),
             stateTracker,
             beamFnStateClient,
             bundleFinalizationCallbacks,
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
index acf3a60ab4e..f65237c986e 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
@@ -27,7 +27,7 @@ import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
-import static org.powermock.api.mockito.PowerMockito.mockStatic;
+import static org.mockito.Mockito.when;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -52,7 +52,9 @@ import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
+import org.apache.beam.sdk.metrics.MetricsContainer;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
+import org.apache.beam.sdk.metrics.MetricsEnvironment.MetricsEnvironmentState;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterable;
@@ -65,14 +67,13 @@ import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.InOrder;
+import org.mockito.Mockito;
 import org.mockito.stubbing.Answer;
-import org.powermock.api.mockito.PowerMockito;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
 
 /** Tests for {@link PCollectionConsumerRegistryTest}. */
-@RunWith(PowerMockRunner.class)
-@PrepareForTest(MetricsEnvironment.class)
+@RunWith(JUnit4.class)
 @SuppressWarnings({
   "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
 })
@@ -126,6 +127,7 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
+            MetricsEnvironment::setCurrentContainer,
             sampler.create(),
             shortIds,
             reporterAndRegistrar,
@@ -187,6 +189,7 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
+            MetricsEnvironment::setCurrentContainer,
             sampler.create(),
             shortIds,
             reporterAndRegistrar,
@@ -214,6 +217,7 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
+            MetricsEnvironment::setCurrentContainer,
             sampler.create(),
             shortIds,
             reporterAndRegistrar,
@@ -273,6 +277,7 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
+            MetricsEnvironment::setCurrentContainer,
             sampler.create(),
             shortIds,
             reporterAndRegistrar,
@@ -337,6 +342,7 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
+            MetricsEnvironment::setCurrentContainer,
             sampler.create(),
             shortIds,
             reporterAndRegistrar,
@@ -367,6 +373,7 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
+            MetricsEnvironment::setCurrentContainer,
             sampler.create(),
             shortIds,
             reporterAndRegistrar,
@@ -383,8 +390,8 @@ public class PCollectionConsumerRegistryTest {
   }
 
   @Test
-  public void testScopedMetricContainerInvokedUponAcceptingElement() throws Exception {
-    mockStatic(MetricsEnvironment.class);
+  public void testMetricContainerUpdatedUponAcceptingElement() throws Exception {
+    MetricsEnvironmentState metricsEnvironmentState = mock(MetricsEnvironmentState.class);
 
     MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
     ShortIdMap shortIds = new ShortIdMap();
@@ -392,6 +399,7 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
+            metricsEnvironmentState,
             sampler.create(),
             shortIds,
             reporterAndRegistrar,
@@ -402,6 +410,13 @@ public class PCollectionConsumerRegistryTest {
     consumers.register(P_COLLECTION_A, "pTransformA", "pTransformAName", consumerA1);
     consumers.register(P_COLLECTION_A, "pTransformB", "pTransformBName", consumerA2);
 
+    // Test both cases; when there is an existing container and where there is no container
+    MetricsContainer oldContainer = mock(MetricsContainer.class);
+    when(metricsEnvironmentState.activate(metricsContainerRegistry.getContainer("pTransformA")))
+        .thenReturn(oldContainer);
+    when(metricsEnvironmentState.activate(metricsContainerRegistry.getContainer("pTransformB")))
+        .thenReturn(null);
+
     FnDataReceiver<WindowedValue<String>> wrapperConsumer =
         (FnDataReceiver<WindowedValue<String>>)
             (FnDataReceiver) consumers.getMultiplexingConsumer(P_COLLECTION_A);
@@ -409,13 +424,18 @@ public class PCollectionConsumerRegistryTest {
     WindowedValue<String> element = valueInGlobalWindow("elem");
     wrapperConsumer.accept(element);
 
-    // Verify that static scopedMetricsContainer is called with pTransformA's container.
-    PowerMockito.verifyStatic(MetricsEnvironment.class, times(1));
-    MetricsEnvironment.scopedMetricsContainer(metricsContainerRegistry.getContainer("pTransformA"));
-
-    // Verify that static scopedMetricsContainer is called with pTransformB's container.
-    PowerMockito.verifyStatic(MetricsEnvironment.class, times(1));
-    MetricsEnvironment.scopedMetricsContainer(metricsContainerRegistry.getContainer("pTransformB"));
+    // Verify that metrics environment state is updated with pTransformA's container, then reset to
+    // the oldContainer, then pTransformB's container and then reset to null.
+    InOrder inOrder = Mockito.inOrder(metricsEnvironmentState);
+    inOrder
+        .verify(metricsEnvironmentState)
+        .activate(metricsContainerRegistry.getContainer("pTransformA"));
+    inOrder.verify(metricsEnvironmentState).activate(oldContainer);
+    inOrder
+        .verify(metricsEnvironmentState)
+        .activate(metricsContainerRegistry.getContainer("pTransformB"));
+    inOrder.verify(metricsEnvironmentState).activate(null);
+    inOrder.verifyNoMoreInteractions();
   }
 
   @Test
@@ -428,6 +448,7 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
+            MetricsEnvironment::setCurrentContainer,
             sampler.create(),
             shortIds,
             reporterAndRegistrar,
@@ -459,6 +480,7 @@ public class PCollectionConsumerRegistryTest {
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
             metricsContainerRegistry,
+            MetricsEnvironment::setCurrentContainer,
             sampler.create(),
             shortIds,
             reporterAndRegistrar,
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistryTest.java
index 35025a61a71..7def4286258 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistryTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistryTest.java
@@ -21,8 +21,7 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
-import static org.powermock.api.mockito.PowerMockito.mockStatic;
+import static org.mockito.Mockito.when;
 
 import java.util.concurrent.atomic.AtomicBoolean;
 import org.apache.beam.fn.harness.control.ExecutionStateSampler;
@@ -32,19 +31,20 @@ import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
 import org.apache.beam.runners.core.metrics.ShortIdMap;
 import org.apache.beam.sdk.function.ThrowingRunnable;
+import org.apache.beam.sdk.metrics.MetricsContainer;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
+import org.apache.beam.sdk.metrics.MetricsEnvironment.MetricsEnvironmentState;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
-import org.powermock.api.mockito.PowerMockito;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
+import org.junit.runners.JUnit4;
+import org.mockito.InOrder;
+import org.mockito.Mockito;
 
 /** Tests for {@link PTransformFunctionRegistry}. */
-@RunWith(PowerMockRunner.class)
-@PrepareForTest(MetricsEnvironment.class)
+@RunWith(JUnit4.class)
 public class PTransformFunctionRegistryTest {
 
   private ExecutionStateSampler sampler;
@@ -65,6 +65,7 @@ public class PTransformFunctionRegistryTest {
     PTransformFunctionRegistry testObject =
         new PTransformFunctionRegistry(
             mock(MetricsContainerStepMap.class),
+            MetricsEnvironment::setCurrentContainer,
             new ShortIdMap(),
             executionStateTracker,
             Urns.START_BUNDLE_MSECS);
@@ -110,12 +111,13 @@ public class PTransformFunctionRegistryTest {
 
   @Test
   public void testMetricsUponRunningFunctions() throws Exception {
+    MetricsEnvironmentState metricsEnvironmentState = mock(MetricsEnvironmentState.class);
     ExecutionStateTracker executionStateTracker = sampler.create();
-    mockStatic(MetricsEnvironment.class);
     MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
     PTransformFunctionRegistry testObject =
         new PTransformFunctionRegistry(
             metricsContainerRegistry,
+            metricsEnvironmentState,
             new ShortIdMap(),
             executionStateTracker,
             Urns.START_BUNDLE_MSECS);
@@ -125,18 +127,30 @@ public class PTransformFunctionRegistryTest {
     testObject.register("pTransformA", "pTranformAName", runnableA);
     testObject.register("pTransformB", "pTranformBName", runnableB);
 
+    // Test both cases; when there is an existing container and where there is no container
+    MetricsContainer oldContainer = mock(MetricsContainer.class);
+    when(metricsEnvironmentState.activate(metricsContainerRegistry.getContainer("pTransformA")))
+        .thenReturn(oldContainer);
+    when(metricsEnvironmentState.activate(metricsContainerRegistry.getContainer("pTransformB")))
+        .thenReturn(null);
+
     executionStateTracker.start("testBundleId");
     for (ThrowingRunnable func : testObject.getFunctions()) {
       func.run();
     }
     executionStateTracker.reset();
 
-    // Verify that static scopedMetricsContainer is called with pTransformA's container.
-    PowerMockito.verifyStatic(MetricsEnvironment.class, times(1));
-    MetricsEnvironment.scopedMetricsContainer(metricsContainerRegistry.getContainer("pTransformA"));
-
-    // Verify that static scopedMetricsContainer is called with pTransformB's container.
-    PowerMockito.verifyStatic(MetricsEnvironment.class, times(1));
-    MetricsEnvironment.scopedMetricsContainer(metricsContainerRegistry.getContainer("pTransformB"));
+    // Verify that metrics environment state is updated with pTransformA's container, then reset to
+    // the oldContainer, then pTransformB's container and then reset to null.
+    InOrder inOrder = Mockito.inOrder(metricsEnvironmentState);
+    inOrder
+        .verify(metricsEnvironmentState)
+        .activate(metricsContainerRegistry.getContainer("pTransformA"));
+    inOrder.verify(metricsEnvironmentState).activate(oldContainer);
+    inOrder
+        .verify(metricsEnvironmentState)
+        .activate(metricsContainerRegistry.getContainer("pTransformB"));
+    inOrder.verify(metricsEnvironmentState).activate(null);
+    inOrder.verifyNoMoreInteractions();
   }
 }