You are viewing a plain-text version of this content. (The link to the canonical version is not included in this copy.)
Posted to commits@beam.apache.org by sc...@apache.org on 2018/12/13 18:10:20 UTC

[beam] branch master updated: [BEAM-6138] Add User Counter Metric Support to Java SDK (#6799)

This is an automated email from the ASF dual-hosted git repository.

scott pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 0b3b9e0  [BEAM-6138] Add User Counter Metric Support to Java SDK (#6799)
0b3b9e0 is described below

commit 0b3b9e0b99ad798836b789b181b51d47cc5234a6
Author: Alex Amato <aj...@google.com>
AuthorDate: Thu Dec 13 10:10:12 2018 -0800

    [BEAM-6138] Add User Counter Metric Support to Java SDK (#6799)
---
 .../runners/core/metrics/MetricsContainerImpl.java |  19 +++
 .../fnexecution/control/RemoteExecutionTest.java   | 129 +++++++++++++++++++
 .../fn/harness/control/ProcessBundleHandler.java   |  38 ++++--
 .../beam/fn/harness/FnApiDoFnRunnerTest.java       | 136 +++++++++++++++++++++
 4 files changed, 309 insertions(+), 13 deletions(-)

diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerImpl.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerImpl.java
index 95bfa74..9147919 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerImpl.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerImpl.java
@@ -21,8 +21,10 @@ import static com.google.common.base.Preconditions.checkNotNull;
 
 import com.google.common.collect.ImmutableList;
 import java.io.Serializable;
+import java.util.ArrayList;
 import java.util.Map;
 import javax.annotation.Nullable;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.MonitoringInfo;
 import org.apache.beam.runners.core.construction.metrics.MetricKey;
 import org.apache.beam.runners.core.metrics.MetricUpdates.MetricUpdate;
 import org.apache.beam.sdk.annotations.Experimental;
@@ -136,6 +138,23 @@ public class MetricsContainerImpl implements Serializable, MetricsContainer {
         extractUpdates(counters), extractUpdates(distributions), extractUpdates(gauges));
   }
 
+  /**
+   * Return the counters in this container as MonitoringInfos; distributions and gauges are not
+   * exported. Values are read via {@link #getUpdates()} (presumably only uncommitted changes).
+   */
+  public Iterable<MonitoringInfo> getMonitoringInfos() {
+    ArrayList<MonitoringInfo> monitoringInfos = new ArrayList<>();
+    for (MetricUpdate<Long> counterUpdate : getUpdates().counterUpdates()) {
+      MetricName metricName = counterUpdate.getKey().metricName();
+      // NOTE(review): confirm what the boolean passed to the builder controls (validation?).
+      SimpleMonitoringInfoBuilder builder = new SimpleMonitoringInfoBuilder(true);
+      builder.setUrnForUserMetric(metricName.getNamespace(), metricName.getName());
+      builder.setInt64Value(counterUpdate.getUpdate());
+      builder.setTimestampToNow();
+      monitoringInfos.add(builder.build());
+    }
+    return monitoringInfos;
+
   private void commitUpdates(MetricsMap<MetricName, ? extends MetricCell<?>> cells) {
     for (MetricCell<?> cell : cells.values()) {
       cell.getDirty().afterCommit();
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
index f53257f..1fdf6db 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
@@ -50,12 +50,16 @@ import java.util.concurrent.Future;
 import java.util.concurrent.ThreadFactory;
 import java.util.function.Function;
 import org.apache.beam.fn.harness.FnHarness;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.MonitoringInfo;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.Target;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.runners.core.construction.PipelineTranslation;
 import org.apache.beam.runners.core.construction.graph.ExecutableStage;
 import org.apache.beam.runners.core.construction.graph.FusedPipeline;
 import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
+import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
 import org.apache.beam.runners.fnexecution.GrpcContextHeaderAccessorProvider;
 import org.apache.beam.runners.fnexecution.GrpcFnServer;
 import org.apache.beam.runners.fnexecution.InProcessServerFactory;
@@ -83,6 +87,7 @@ import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.fn.stream.OutboundObserverFactory;
 import org.apache.beam.sdk.fn.test.InProcessManagedChannelFactory;
+import org.apache.beam.sdk.metrics.Metrics;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.state.BagState;
 import org.apache.beam.sdk.state.ReadableState;
@@ -109,16 +114,20 @@ import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionList;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.vendor.grpc.v1_13_1.com.google.protobuf.ByteString;
+import org.hamcrest.CoreMatchers;
 import org.hamcrest.collection.IsEmptyIterable;
 import org.hamcrest.collection.IsIterableContainingInOrder;
 import org.joda.time.DateTimeUtils;
 import org.joda.time.Duration;
 import org.junit.After;
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * Tests the execution of a pipeline from specification time to executing a single fused stage,
@@ -128,6 +137,8 @@ import org.junit.runners.JUnit4;
 public class RemoteExecutionTest implements Serializable {
   @Rule public transient ResetDateTimeProvider resetDateTimeProvider = new ResetDateTimeProvider();
 
+  private static final Logger LOG = LoggerFactory.getLogger(RemoteExecutionTest.class);
+
   private transient GrpcFnServer<FnApiControlClientPoolService> controlServer;
   private transient GrpcFnServer<GrpcDataService> dataServer;
   private transient GrpcFnServer<GrpcStateService> stateServer;
@@ -487,6 +498,124 @@ public class RemoteExecutionTest implements Serializable {
   }
 
   @Test
+  public void testMetrics() throws Exception {
+    final String counterMetricName = "counterMetric";
+    Pipeline p = Pipeline.create();
+    // The DoFn emits nothing; it only increments a user counter once per input element.
+    p.apply("impulse", Impulse.create())
+        .apply(
+            "create",
+            ParDo.of(
+                new DoFn<byte[], String>() {
+                  @ProcessElement
+                  public void process(ProcessContext ctxt) {
+                    Metrics.counter(RemoteExecutionTest.class, counterMetricName).inc();
+                  }
+                }))
+        .setCoder(StringUtf8Coder.of());
+
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+    FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
+    Optional<ExecutableStage> optionalStage =
+        Iterables.tryFind(fused.getFusedStages(), (ExecutableStage stage) -> true);
+    checkState(optionalStage.isPresent(), "Expected a fused stage.");
+    ExecutableStage stage = optionalStage.get();
+
+    ExecutableProcessBundleDescriptor descriptor =
+        ProcessBundleDescriptors.fromExecutableStage(
+            "test_stage",
+            stage,
+            dataServer.getApiServiceDescriptor(),
+            stateServer.getApiServiceDescriptor());
+
+    BundleProcessor processor =
+        controlClient.getProcessor(
+            descriptor.getProcessBundleDescriptor(),
+            descriptor.getRemoteInputDestinations(),
+            stateDelegator);
+
+    // Outputs are collected but not inspected; this test only checks the reported metrics.
+    Map<Target, Coder<WindowedValue<?>>> outputTargets = descriptor.getOutputTargetCoders();
+    Map<Target, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
+    for (Entry<Target, Coder<WindowedValue<?>>> targetCoder : outputTargets.entrySet()) {
+      List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
+      outputReceivers.put(
+          targetCoder.getKey(),
+          RemoteOutputReceiver.of(targetCoder.getValue(), outputContents::add));
+    }
+
+    // The pipeline declares no side inputs, so this handler is not expected to be consulted.
+    Iterable<byte[]> sideInputData =
+        Arrays.asList(
+            CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A"),
+            CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B"),
+            CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C"));
+
+    StateRequestHandler stateRequestHandler =
+        StateRequestHandlers.forSideInputHandlerFactory(
+            descriptor.getSideInputSpecs(),
+            new SideInputHandlerFactory() {
+              @Override
+              public <T, V, W extends BoundedWindow> SideInputHandler<V, W> forSideInput(
+                  String pTransformId,
+                  String sideInputId,
+                  RunnerApi.FunctionSpec accessPattern,
+                  Coder<T> elementCoder,
+                  Coder<W> windowCoder) {
+                return new SideInputHandler<V, W>() {
+                  @Override
+                  public Iterable<V> get(byte[] key, W window) {
+                    return (Iterable) sideInputData;
+                  }
+
+                  @Override
+                  public Coder<V> resultCoder() {
+                    return ((KvCoder) elementCoder).getValueCoder();
+                  }
+                };
+              }
+            });
+
+    // Record completion responses so the metrics are checked after the bundle has closed,
+    // instead of asserting inside a callback that might never run.
+    List<ProcessBundleResponse> completedResponses =
+        Collections.synchronizedList(new ArrayList<>());
+    BundleProgressHandler progressHandler =
+        new BundleProgressHandler() {
+          @Override
+          public void onProgress(ProcessBundleProgressResponse progress) {}
+
+          @Override
+          public void onCompleted(ProcessBundleResponse response) {
+            completedResponses.add(response);
+          }
+        };
+
+    try (ActiveBundle bundle =
+        processor.newBundle(outputReceivers, stateRequestHandler, progressHandler)) {
+      Iterables.getOnlyElement(bundle.getInputReceivers().values())
+          .accept(
+              WindowedValue.valueInGlobalWindow(
+                  CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "X")));
+      Iterables.getOnlyElement(bundle.getInputReceivers().values())
+          .accept(
+              WindowedValue.valueInGlobalWindow(
+                  CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "Y")));
+    }
+
+    // Fails unless onCompleted was invoked exactly once by the time the bundle closed.
+    ProcessBundleResponse completedResponse = Iterables.getOnlyElement(completedResponses);
+
+    // Timestamps must be set; clear them so the comparison does not depend on the clock.
+    List<MonitoringInfo> actualMIs = new ArrayList<>();
+    for (MonitoringInfo mi : completedResponse.getMonitoringInfosList()) {
+      Assert.assertTrue(mi.getTimestamp().getSeconds() > 0);
+      actualMIs.add(mi.toBuilder().clearTimestamp().build());
+    }
+
+    // One increment per input element ("X" and "Y").
+    SimpleMonitoringInfoBuilder builder = new SimpleMonitoringInfoBuilder();
+    builder.setUrnForUserMetric(RemoteExecutionTest.class.getName(), counterMetricName);
+    builder.setInt64Value(2);
+    MonitoringInfo expectedCounter = builder.build();
+
+    assertThat(actualMIs, CoreMatchers.hasItems(expectedCounter));
+  }
+
+  @Test
   public void testExecutionWithUserState() throws Exception {
     Pipeline p = Pipeline.create();
     final String stateId = "foo";
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index e22a6c2..3018343 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -26,6 +26,7 @@ import com.google.common.collect.Lists;
 import com.google.common.collect.Multimap;
 import com.google.common.collect.SetMultimap;
 import com.google.common.collect.Sets;
+import java.io.Closeable;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.HashSet;
@@ -47,6 +48,7 @@ import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleApplication;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.MonitoringInfo;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
@@ -60,8 +62,10 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
 import org.apache.beam.model.pipeline.v1.RunnerApi.WindowingStrategy;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.fn.function.ThrowingRunnable;
+import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.common.ReflectHelpers;
@@ -292,21 +296,29 @@ public class ProcessBundleHandler {
             splitListener);
       }
 
-      // Already in reverse topological order so we don't need to do anything.
-      for (ThrowingRunnable startFunction : startFunctions) {
-        LOG.debug("Starting function {}", startFunction);
-        startFunction.run();
-      }
+      MetricsContainerImpl metricsContainer = new MetricsContainerImpl(request.getInstructionId());
+      try (Closeable closeable = MetricsEnvironment.scopedMetricsContainer(metricsContainer)) {
+
+        // Already in reverse topological order so we don't need to do anything.
+        for (ThrowingRunnable startFunction : startFunctions) {
+          LOG.debug("Starting function {}", startFunction);
+          startFunction.run();
+        }
 
-      queueingClient.drainAndBlock();
+        queueingClient.drainAndBlock();
 
-      // Need to reverse this since we want to call finish in topological order.
-      for (ThrowingRunnable finishFunction : Lists.reverse(finishFunctions)) {
-        LOG.debug("Finishing function {}", finishFunction);
-        finishFunction.run();
-      }
-      if (!allResiduals.isEmpty()) {
-        response.addAllResidualRoots(allResiduals.values());
+        // Need to reverse this since we want to call finish in topological order.
+        for (ThrowingRunnable finishFunction : Lists.reverse(finishFunctions)) {
+          LOG.debug("Finishing function {}", finishFunction);
+          finishFunction.run();
+        }
+        if (!allResiduals.isEmpty()) {
+          response.addAllResidualRoots(allResiduals.values());
+        }
+
+        for (MonitoringInfo mi : metricsContainer.getMonitoringInfos()) {
+          response.addMonitoringInfos(mi);
+        }
       }
     }
 
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index 9149edb..5d8c413 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -27,11 +27,13 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.fail;
 
+import com.google.auto.value.AutoValue;
 import com.google.common.base.Suppliers;
 import com.google.common.collect.ArrayListMultimap;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.ListMultimap;
+import java.io.Closeable;
 import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
@@ -44,10 +46,18 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.Environment;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.PipelineTranslation;
 import org.apache.beam.runners.core.construction.SdkComponents;
+import org.apache.beam.runners.core.metrics.MetricUpdates;
+import org.apache.beam.runners.core.metrics.MetricUpdates.MetricUpdate;
+import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.fn.function.ThrowingRunnable;
+import org.apache.beam.sdk.metrics.Counter;
+import org.apache.beam.sdk.metrics.MetricName;
+import org.apache.beam.sdk.metrics.Metrics;
+import org.apache.beam.sdk.metrics.MetricsContainer;
+import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.state.BagState;
 import org.apache.beam.sdk.state.CombiningState;
@@ -86,6 +96,8 @@ import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /** Tests for {@link FnApiDoFnRunner}. */
 @RunWith(JUnit4.class)
@@ -93,6 +105,8 @@ public class FnApiDoFnRunnerTest implements Serializable {
 
   @Rule public transient ResetDateTimeProvider dateTimeProvider = new ResetDateTimeProvider();
 
+  private static final Logger LOG = LoggerFactory.getLogger(FnApiDoFnRunnerTest.class);
+
   public static final String TEST_PTRANSFORM_ID = "pTransformId";
 
   private static class ConcatCombineFn extends CombineFn<String, String, String> {
@@ -408,6 +422,9 @@ public class FnApiDoFnRunnerTest implements Serializable {
 
   private static class TestSideInputIsAccessibleForDownstreamCallersDoFn
       extends DoFn<String, Iterable<String>> {
+    private final Counter countedElements =
+        Metrics.counter(TestSideInputIsAccessibleForDownstreamCallersDoFn.class, "countedElems");
+
     private final PCollectionView<Iterable<String>> iterableSideInput;
 
     private TestSideInputIsAccessibleForDownstreamCallersDoFn(
@@ -417,6 +434,7 @@ public class FnApiDoFnRunnerTest implements Serializable {
 
     @ProcessElement
     public void processElement(ProcessContext context) {
+      countedElements.inc();
       context.output(context.sideInput(iterableSideInput));
     }
   }
@@ -514,6 +532,124 @@ public class FnApiDoFnRunnerTest implements Serializable {
     assertEquals(stateData, fakeClient.getData());
   }
 
+  /**
+   * Immutable value holding the step name, metric name and value that a test expects to find in
+   * a {@link MetricUpdate}.
+   */
+  @AutoValue
+  public abstract static class ExpectedMetric implements Serializable {
+    public abstract String stepName();
+
+    public abstract MetricName metricName();
+
+    public abstract long value();
+
+    static ExpectedMetric create(String step, MetricName name, long expectedValue) {
+      return new AutoValue_FnApiDoFnRunnerTest_ExpectedMetric(step, name, expectedValue);
+    }
+  }
+
+  @Test
+  public void testUsingMetrics() throws Exception {
+    MetricsContainerImpl metricsContainer = new MetricsContainerImpl("testUsingMetrics");
+    FixedWindows windowFn = FixedWindows.of(Duration.millis(1L));
+    IntervalWindow windowA = windowFn.assignWindow(new Instant(1L));
+    IntervalWindow windowB = windowFn.assignWindow(new Instant(2L));
+    ByteString encodedWindowA =
+        ByteString.copyFrom(CoderUtils.encodeToByteArray(windowFn.windowCoder(), windowA));
+    ByteString encodedWindowB =
+        ByteString.copyFrom(CoderUtils.encodeToByteArray(windowFn.windowCoder(), windowB));
+
+    Pipeline p = Pipeline.create();
+    PCollection<String> valuePCollection =
+        p.apply(Create.of("unused")).apply(Window.into(windowFn));
+    PCollectionView<Iterable<String>> iterableSideInputView =
+        valuePCollection.apply(View.asIterable());
+    PCollection<Iterable<String>> outputPCollection =
+        valuePCollection.apply(
+            TEST_PTRANSFORM_ID,
+            ParDo.of(new TestSideInputIsAccessibleForDownstreamCallersDoFn(iterableSideInputView))
+                .withSideInputs(iterableSideInputView));
+
+    SdkComponents sdkComponents = SdkComponents.create(p.getOptions());
+    RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents, true);
+    String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
+
+    RunnerApi.PTransform pTransform =
+        pProto
+            .getComponents()
+            .getTransformsOrThrow(
+                pProto
+                    .getComponents()
+                    .getTransformsOrThrow(TEST_PTRANSFORM_ID)
+                    .getSubtransforms(0));
+
+    ImmutableMap<StateKey, ByteString> stateData =
+        ImmutableMap.of(
+            multimapSideInputKey(
+                iterableSideInputView.getTagInternal().getId(), ByteString.EMPTY, encodedWindowA),
+            encode("iterableValue1A", "iterableValue2A", "iterableValue3A"),
+            multimapSideInputKey(
+                iterableSideInputView.getTagInternal().getId(), ByteString.EMPTY, encodedWindowB),
+            encode("iterableValue1B", "iterableValue2B", "iterableValue3B"));
+
+    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
+
+    List<WindowedValue<Iterable<String>>> mainOutputValues = new ArrayList<>();
+    ListMultimap<String, FnDataReceiver<WindowedValue<?>>> consumers = ArrayListMultimap.create();
+    consumers.put(
+        Iterables.getOnlyElement(pTransform.getOutputsMap().values()),
+        (FnDataReceiver) (FnDataReceiver<WindowedValue<Iterable<String>>>) mainOutputValues::add);
+    List<ThrowingRunnable> startFunctions = new ArrayList<>();
+    List<ThrowingRunnable> finishFunctions = new ArrayList<>();
+
+    // try-with-resources guarantees the scoped container is closed even if the runner or an
+    // assertion throws, so it cannot leak into other tests.
+    try (Closeable closeable = MetricsEnvironment.scopedMetricsContainer(metricsContainer)) {
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              fakeClient,
+              TEST_PTRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctions::add,
+              finishFunctions::add,
+              null /* splitListener */);
+
+      Iterables.getOnlyElement(startFunctions).run();
+      mainOutputValues.clear();
+
+      // Each element processed by the DoFn increments its "countedElems" counter once.
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          Iterables.getOnlyElement(consumers.get(inputPCollectionId));
+      mainInput.accept(valueInWindow("X", windowA));
+      mainInput.accept(valueInWindow("Y", windowB));
+
+      // Sanity check: the container scoped above is still the current container.
+      MetricsContainer currentContainer = MetricsEnvironment.getCurrentContainer();
+      assertEquals(metricsContainer, currentContainer);
+    }
+
+    MetricName metricName =
+        MetricName.named(TestSideInputIsAccessibleForDownstreamCallersDoFn.class, "countedElems");
+    ExpectedMetric expected = ExpectedMetric.create("testUsingMetrics", metricName, 2);
+
+    // Exactly one counter update is expected, reflecting both processed elements.
+    MetricUpdates updates = metricsContainer.getUpdates();
+    MetricUpdate<Long> update = Iterables.getOnlyElement(updates.counterUpdates());
+    assertEquals(expected.metricName(), update.getKey().metricName());
+    assertEquals(expected.stepName(), update.getKey().stepName());
+    // Unbox explicitly: assertEquals(long, Long) would be an ambiguous overload.
+    assertEquals(expected.value(), (long) update.getUpdate());
+  }
+
   private static class TestTimerfulDoFn extends DoFn<KV<String, String>, String> {
     @StateId("bag")
     private final StateSpec<BagState<String>> bagStateSpec = StateSpecs.bag(StringUtf8Coder.of());