You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by sc...@apache.org on 2018/12/13 18:10:20 UTC
[beam] branch master updated: [BEAM-6138] Add User Counter Metric
Support to Java SDK (#6799)
This is an automated email from the ASF dual-hosted git repository.
scott pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 0b3b9e0 [BEAM-6138] Add User Counter Metric Support to Java SDK (#6799)
0b3b9e0 is described below
commit 0b3b9e0b99ad798836b789b181b51d47cc5234a6
Author: Alex Amato <aj...@google.com>
AuthorDate: Thu Dec 13 10:10:12 2018 -0800
[BEAM-6138] Add User Counter Metric Support to Java SDK (#6799)
---
.../runners/core/metrics/MetricsContainerImpl.java | 19 +++
.../fnexecution/control/RemoteExecutionTest.java | 129 +++++++++++++++++++
.../fn/harness/control/ProcessBundleHandler.java | 38 ++++--
.../beam/fn/harness/FnApiDoFnRunnerTest.java | 136 +++++++++++++++++++++
4 files changed, 309 insertions(+), 13 deletions(-)
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerImpl.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerImpl.java
index 95bfa74..9147919 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerImpl.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerImpl.java
@@ -21,8 +21,10 @@ import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.collect.ImmutableList;
import java.io.Serializable;
+import java.util.ArrayList;
import java.util.Map;
import javax.annotation.Nullable;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.MonitoringInfo;
import org.apache.beam.runners.core.construction.metrics.MetricKey;
import org.apache.beam.runners.core.metrics.MetricUpdates.MetricUpdate;
import org.apache.beam.sdk.annotations.Experimental;
@@ -136,6 +138,23 @@ public class MetricsContainerImpl implements Serializable, MetricsContainer {
extractUpdates(counters), extractUpdates(distributions), extractUpdates(gauges));
}
+ /** Returns the counter updates reported by {@link #getUpdates()} as MonitoringInfos. */
+ public Iterable<MonitoringInfo> getMonitoringInfos() {
+ // Only counters are converted below; distribution and gauge updates are not exported here.
+ ArrayList<MonitoringInfo> monitoringInfos = new ArrayList<MonitoringInfo>();
+ MetricUpdates mus = this.getUpdates();
+
+ for (MetricUpdate<Long> mu : mus.counterUpdates()) {
+ SimpleMonitoringInfoBuilder builder = new SimpleMonitoringInfoBuilder(true);
+ builder.setUrnForUserMetric(
+ mu.getKey().metricName().getNamespace(), mu.getKey().metricName().getName());
+ builder.setInt64Value(mu.getUpdate());
+ builder.setTimestampToNow();
+ monitoringInfos.add(builder.build());
+ }
+ return monitoringInfos;
+ }
+
private void commitUpdates(MetricsMap<MetricName, ? extends MetricCell<?>> cells) {
for (MetricCell<?> cell : cells.values()) {
cell.getDirty().afterCommit();
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
index f53257f..1fdf6db 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
@@ -50,12 +50,16 @@ import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.function.Function;
import org.apache.beam.fn.harness.FnHarness;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.MonitoringInfo;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.Target;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.PipelineTranslation;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.core.construction.graph.FusedPipeline;
import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
+import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
import org.apache.beam.runners.fnexecution.GrpcContextHeaderAccessorProvider;
import org.apache.beam.runners.fnexecution.GrpcFnServer;
import org.apache.beam.runners.fnexecution.InProcessServerFactory;
@@ -83,6 +87,7 @@ import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.fn.stream.OutboundObserverFactory;
import org.apache.beam.sdk.fn.test.InProcessManagedChannelFactory;
+import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.ReadableState;
@@ -109,16 +114,20 @@ import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.vendor.grpc.v1_13_1.com.google.protobuf.ByteString;
+import org.hamcrest.CoreMatchers;
import org.hamcrest.collection.IsEmptyIterable;
import org.hamcrest.collection.IsIterableContainingInOrder;
import org.joda.time.DateTimeUtils;
import org.joda.time.Duration;
import org.junit.After;
+import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/**
* Tests the execution of a pipeline from specification time to executing a single fused stage,
@@ -128,6 +137,8 @@ import org.junit.runners.JUnit4;
public class RemoteExecutionTest implements Serializable {
@Rule public transient ResetDateTimeProvider resetDateTimeProvider = new ResetDateTimeProvider();
+ private static final Logger LOG = LoggerFactory.getLogger(RemoteExecutionTest.class);
+
private transient GrpcFnServer<FnApiControlClientPoolService> controlServer;
private transient GrpcFnServer<GrpcDataService> dataServer;
private transient GrpcFnServer<GrpcStateService> stateServer;
@@ -487,6 +498,124 @@ public class RemoteExecutionTest implements Serializable {
}
@Test
+ public void testMetrics() throws Exception {
+ final String counterMetricName = "counterMetric";
+ Pipeline p = Pipeline.create();
+ PCollection<String> input =
+ p.apply("impulse", Impulse.create())
+ .apply(
+ "create",
+ ParDo.of(
+ new DoFn<byte[], String>() {
+ @ProcessElement
+ public void process(ProcessContext ctxt) {
+ Metrics.counter(RemoteExecutionTest.class, counterMetricName).inc();
+ }
+ }))
+ .setCoder(StringUtf8Coder.of());
+
+ RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+ FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
+ Optional<ExecutableStage> optionalStage =
+ Iterables.tryFind(fused.getFusedStages(), (ExecutableStage stage) -> true);
+ checkState(optionalStage.isPresent(), "Expected a stage with side inputs."); // NOTE(review): this pipeline declares no side inputs; message may be copied.
+ ExecutableStage stage = optionalStage.get();
+
+ ExecutableProcessBundleDescriptor descriptor =
+ ProcessBundleDescriptors.fromExecutableStage(
+ "test_stage",
+ stage,
+ dataServer.getApiServiceDescriptor(),
+ stateServer.getApiServiceDescriptor());
+
+ BundleProcessor processor =
+ controlClient.getProcessor(
+ descriptor.getProcessBundleDescriptor(),
+ descriptor.getRemoteInputDestinations(),
+ stateDelegator);
+
+ Map<Target, Coder<WindowedValue<?>>> outputTargets = descriptor.getOutputTargetCoders();
+ Map<Target, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
+ Map<Target, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
+ for (Entry<Target, Coder<WindowedValue<?>>> targetCoder : outputTargets.entrySet()) {
+ List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
+ outputValues.put(targetCoder.getKey(), outputContents);
+ outputReceivers.put(
+ targetCoder.getKey(),
+ RemoteOutputReceiver.of(targetCoder.getValue(), outputContents::add));
+ }
+
+ Iterable<byte[]> sideInputData =
+ Arrays.asList(
+ CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A"),
+ CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B"),
+ CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C"));
+
+ StateRequestHandler stateRequestHandler =
+ StateRequestHandlers.forSideInputHandlerFactory(
+ descriptor.getSideInputSpecs(),
+ new SideInputHandlerFactory() {
+ @Override
+ public <T, V, W extends BoundedWindow> SideInputHandler<V, W> forSideInput(
+ String pTransformId,
+ String sideInputId,
+ RunnerApi.FunctionSpec accessPattern,
+ Coder<T> elementCoder,
+ Coder<W> windowCoder) {
+ return new SideInputHandler<V, W>() {
+ @Override
+ public Iterable<V> get(byte[] key, W window) {
+ return (Iterable) sideInputData;
+ }
+
+ @Override
+ public Coder<V> resultCoder() {
+ return ((KvCoder) elementCoder).getValueCoder();
+ }
+ };
+ }
+ });
+
+ BundleProgressHandler progressHandler =
+ new BundleProgressHandler() {
+ @Override
+ public void onProgress(ProcessBundleProgressResponse progress) {}
+
+ @Override
+ public void onCompleted(ProcessBundleResponse response) {
+ // Assert each timestamp is set, then clear it so the comparison below ignores timestamps.
+ List<MonitoringInfo> actualMIs = new ArrayList<>();
+ for (MonitoringInfo mi : response.getMonitoringInfosList()) {
+ MonitoringInfo.Builder builder = MonitoringInfo.newBuilder();
+ Assert.assertTrue(mi.getTimestamp().getSeconds() > 0);
+ builder.mergeFrom(mi);
+ builder.clearTimestamp();
+ actualMIs.add(builder.build());
+ }
+
+ SimpleMonitoringInfoBuilder builder = new SimpleMonitoringInfoBuilder();
+ builder.setUrnForUserMetric(RemoteExecutionTest.class.getName(), counterMetricName);
+ builder.setInt64Value(2); // two elements ("X" and "Y") are sent below; the DoFn calls inc() once per element
+ MonitoringInfo expectedCounter = builder.build();
+
+ assertThat(actualMIs, CoreMatchers.hasItems(expectedCounter));
+ }
+ };
+
+ try (ActiveBundle bundle =
+ processor.newBundle(outputReceivers, stateRequestHandler, progressHandler)) {
+ Iterables.getOnlyElement(bundle.getInputReceivers().values())
+ .accept(
+ WindowedValue.valueInGlobalWindow(
+ CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "X")));
+ Iterables.getOnlyElement(bundle.getInputReceivers().values())
+ .accept(
+ WindowedValue.valueInGlobalWindow(
+ CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "Y")));
+ }
+ }
+
+ @Test
public void testExecutionWithUserState() throws Exception {
Pipeline p = Pipeline.create();
final String stateId = "foo";
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index e22a6c2..3018343 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -26,6 +26,7 @@ import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import com.google.common.collect.SetMultimap;
import com.google.common.collect.Sets;
+import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
@@ -47,6 +48,7 @@ import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleApplication;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.MonitoringInfo;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
@@ -60,8 +62,10 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
import org.apache.beam.model.pipeline.v1.RunnerApi.WindowingStrategy;
import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.fn.function.ThrowingRunnable;
+import org.apache.beam.sdk.metrics.MetricsEnvironment;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.common.ReflectHelpers;
@@ -292,21 +296,29 @@ public class ProcessBundleHandler {
splitListener);
}
- // Already in reverse topological order so we don't need to do anything.
- for (ThrowingRunnable startFunction : startFunctions) {
- LOG.debug("Starting function {}", startFunction);
- startFunction.run();
- }
+ MetricsContainerImpl metricsContainer = new MetricsContainerImpl(request.getInstructionId());
+ try (Closeable closeable = MetricsEnvironment.scopedMetricsContainer(metricsContainer)) {
+
+ // Already in reverse topological order so we don't need to do anything.
+ for (ThrowingRunnable startFunction : startFunctions) {
+ LOG.debug("Starting function {}", startFunction);
+ startFunction.run();
+ }
- queueingClient.drainAndBlock();
+ queueingClient.drainAndBlock();
- // Need to reverse this since we want to call finish in topological order.
- for (ThrowingRunnable finishFunction : Lists.reverse(finishFunctions)) {
- LOG.debug("Finishing function {}", finishFunction);
- finishFunction.run();
- }
- if (!allResiduals.isEmpty()) {
- response.addAllResidualRoots(allResiduals.values());
+ // Need to reverse this since we want to call finish in topological order.
+ for (ThrowingRunnable finishFunction : Lists.reverse(finishFunctions)) {
+ LOG.debug("Finishing function {}", finishFunction);
+ finishFunction.run();
+ }
+ if (!allResiduals.isEmpty()) {
+ response.addAllResidualRoots(allResiduals.values());
+ }
+
+ for (MonitoringInfo mi : metricsContainer.getMonitoringInfos()) {
+ response.addMonitoringInfos(mi);
+ }
}
}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index 9149edb..5d8c413 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -27,11 +27,13 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
+import com.google.auto.value.AutoValue;
import com.google.common.base.Suppliers;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
+import java.io.Closeable;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
@@ -44,10 +46,18 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.Environment;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.PipelineTranslation;
import org.apache.beam.runners.core.construction.SdkComponents;
+import org.apache.beam.runners.core.metrics.MetricUpdates;
+import org.apache.beam.runners.core.metrics.MetricUpdates.MetricUpdate;
+import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.fn.function.ThrowingRunnable;
+import org.apache.beam.sdk.metrics.Counter;
+import org.apache.beam.sdk.metrics.MetricName;
+import org.apache.beam.sdk.metrics.Metrics;
+import org.apache.beam.sdk.metrics.MetricsContainer;
+import org.apache.beam.sdk.metrics.MetricsEnvironment;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
@@ -86,6 +96,8 @@ import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/** Tests for {@link FnApiDoFnRunner}. */
@RunWith(JUnit4.class)
@@ -93,6 +105,8 @@ public class FnApiDoFnRunnerTest implements Serializable {
@Rule public transient ResetDateTimeProvider dateTimeProvider = new ResetDateTimeProvider();
+ private static final Logger LOG = LoggerFactory.getLogger(FnApiDoFnRunnerTest.class);
+
public static final String TEST_PTRANSFORM_ID = "pTransformId";
private static class ConcatCombineFn extends CombineFn<String, String, String> {
@@ -408,6 +422,9 @@ public class FnApiDoFnRunnerTest implements Serializable {
private static class TestSideInputIsAccessibleForDownstreamCallersDoFn
extends DoFn<String, Iterable<String>> {
+ private final Counter countedElements =
+ Metrics.counter(TestSideInputIsAccessibleForDownstreamCallersDoFn.class, "countedElems");
+
private final PCollectionView<Iterable<String>> iterableSideInput;
private TestSideInputIsAccessibleForDownstreamCallersDoFn(
@@ -417,6 +434,7 @@ public class FnApiDoFnRunnerTest implements Serializable {
@ProcessElement
public void processElement(ProcessContext context) {
+ countedElements.inc();
context.output(context.sideInput(iterableSideInput));
}
}
@@ -514,6 +532,124 @@ public class FnApiDoFnRunnerTest implements Serializable {
assertEquals(stateData, fakeClient.getData());
}
+ /**
+ * Simple value class holding the step name, metric name and value that a counter
+ * {@code MetricUpdate} is expected to carry; instances are built with {@link #create}.
+ */
+ @AutoValue
+ public abstract static class ExpectedMetric implements Serializable {
+ static ExpectedMetric create(String stepName, MetricName metricName, long value) {
+ return new AutoValue_FnApiDoFnRunnerTest_ExpectedMetric(stepName, metricName, value);
+ }
+
+ public abstract String stepName();
+
+ public abstract MetricName metricName();
+
+ public abstract long value();
+ }
+
+ @Test
+ public void testUsingMetrics() throws Exception {
+ MetricsContainerImpl metricsContainer = new MetricsContainerImpl("testUsingMetrics");
+ Closeable closeable = MetricsEnvironment.scopedMetricsContainer(metricsContainer);
+ FixedWindows windowFn = FixedWindows.of(Duration.millis(1L));
+ IntervalWindow windowA = windowFn.assignWindow(new Instant(1L));
+ IntervalWindow windowB = windowFn.assignWindow(new Instant(2L));
+ ByteString encodedWindowA =
+ ByteString.copyFrom(CoderUtils.encodeToByteArray(windowFn.windowCoder(), windowA));
+ ByteString encodedWindowB =
+ ByteString.copyFrom(CoderUtils.encodeToByteArray(windowFn.windowCoder(), windowB));
+
+ Pipeline p = Pipeline.create();
+ PCollection<String> valuePCollection =
+ p.apply(Create.of("unused")).apply(Window.into(windowFn));
+ PCollectionView<Iterable<String>> iterableSideInputView =
+ valuePCollection.apply(View.asIterable());
+ PCollection<Iterable<String>> outputPCollection =
+ valuePCollection.apply(
+ TEST_PTRANSFORM_ID,
+ ParDo.of(new TestSideInputIsAccessibleForDownstreamCallersDoFn(iterableSideInputView))
+ .withSideInputs(iterableSideInputView));
+
+ SdkComponents sdkComponents = SdkComponents.create(p.getOptions());
+ RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents, true);
+ String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
+
+ RunnerApi.PTransform pTransform =
+ pProto
+ .getComponents()
+ .getTransformsOrThrow(
+ pProto
+ .getComponents()
+ .getTransformsOrThrow(TEST_PTRANSFORM_ID)
+ .getSubtransforms(0));
+
+ ImmutableMap<StateKey, ByteString> stateData =
+ ImmutableMap.of(
+ multimapSideInputKey(
+ iterableSideInputView.getTagInternal().getId(), ByteString.EMPTY, encodedWindowA),
+ encode("iterableValue1A", "iterableValue2A", "iterableValue3A"),
+ multimapSideInputKey(
+ iterableSideInputView.getTagInternal().getId(), ByteString.EMPTY, encodedWindowB),
+ encode("iterableValue1B", "iterableValue2B", "iterableValue3B"));
+
+ FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
+
+ List<WindowedValue<Iterable<String>>> mainOutputValues = new ArrayList<>();
+ ListMultimap<String, FnDataReceiver<WindowedValue<?>>> consumers = ArrayListMultimap.create();
+ consumers.put(
+ Iterables.getOnlyElement(pTransform.getOutputsMap().values()),
+ (FnDataReceiver) (FnDataReceiver<WindowedValue<Iterable<String>>>) mainOutputValues::add);
+ List<ThrowingRunnable> startFunctions = new ArrayList<>();
+ List<ThrowingRunnable> finishFunctions = new ArrayList<>();
+
+ new FnApiDoFnRunner.Factory<>()
+ .createRunnerForPTransform(
+ PipelineOptionsFactory.create(),
+ null /* beamFnDataClient */,
+ fakeClient,
+ TEST_PTRANSFORM_ID,
+ pTransform,
+ Suppliers.ofInstance("57L")::get,
+ pProto.getComponents().getPcollectionsMap(),
+ pProto.getComponents().getCodersMap(),
+ pProto.getComponents().getWindowingStrategiesMap(),
+ consumers,
+ startFunctions::add,
+ finishFunctions::add,
+ null /* splitListener */);
+
+ Iterables.getOnlyElement(startFunctions).run();
+ mainOutputValues.clear();
+
+ // Send one element in each window to the DoFn's main input.
+ // The DoFn increments its "countedElems" counter once per processed element.
+ FnDataReceiver<WindowedValue<?>> mainInput =
+ Iterables.getOnlyElement(consumers.get(inputPCollectionId));
+ mainInput.accept(valueInWindow("X", windowA));
+ mainInput.accept(valueInWindow("Y", windowB));
+
+ MetricsContainer mc = MetricsEnvironment.getCurrentContainer(); // NOTE(review): mc is not read below.
+ MetricName metricName =
+ MetricName.named(TestSideInputIsAccessibleForDownstreamCallersDoFn.class, "countedElems");
+ List<ExpectedMetric> expectedMetrics = new ArrayList<ExpectedMetric>();
+ expectedMetrics.add(ExpectedMetric.create("testUsingMetrics", metricName, 2));
+
+ closeable.close(); // NOTE(review): skipped if an earlier statement throws; try-with-resources would always close.
+ MetricUpdates updates = metricsContainer.getUpdates();
+
+ // Compare each counter update with the expected metric at the same position.
+ int i = 0;
+ for (MetricUpdate mu : updates.counterUpdates()) {
+ assertEquals(expectedMetrics.get(i).metricName(), mu.getKey().metricName());
+ assertEquals(expectedMetrics.get(i).stepName(), mu.getKey().stepName());
+ assertEquals(expectedMetrics.get(i).value(), mu.getUpdate());
+ i++;
+ }
+ assertEquals(1, i); // Exactly one counter update is expected.
+ }
+
private static class TestTimerfulDoFn extends DoFn<KV<String, String>, String> {
@StateId("bag")
private final StateSpec<BagState<String>> bagStateSpec = StateSpecs.bag(StringUtf8Coder.of());