Posted to commits@beam.apache.org by dh...@apache.org on 2017/05/09 16:36:30 UTC

[2/3] beam git commit: Cherry-pick pull request #2649 into release-2.0.0 branch

Cherry-pick pull request #2649 into release-2.0.0 branch


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/3a4ffd2c
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/3a4ffd2c
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/3a4ffd2c

Branch: refs/heads/release-2.0.0
Commit: 3a4ffd2ce8e90486cf51f420a42599ddf95b9a5d
Parents: bad377c
Author: Aviem Zur <av...@gmail.com>
Authored: Fri May 5 23:13:24 2017 +0300
Committer: Dan Halperin <dh...@google.com>
Committed: Tue May 9 09:36:18 2017 -0700

----------------------------------------------------------------------
 .../apache/beam/runners/core/LateDataUtils.java |   2 +-
 .../beam/runners/flink/FlinkRunnerResult.java   |   8 +-
 .../metrics/DoFnRunnerWithMetricsUpdate.java    |  12 +-
 .../flink/metrics/FlinkMetricContainer.java     | 273 +++--------
 .../flink/metrics/FlinkMetricResults.java       | 146 ------
 .../flink/metrics/MetricsAccumulator.java       |  60 +++
 .../flink/metrics/ReaderInvocationUtil.java     |   7 +-
 .../translation/wrappers/SourceInputFormat.java |   8 +-
 .../streaming/io/BoundedSourceWrapper.java      |   8 +-
 .../streaming/io/UnboundedSourceWrapper.java    |   9 +-
 .../beam/runners/spark/SparkPipelineResult.java |   8 +-
 .../apache/beam/runners/spark/io/SourceRDD.java |   4 +-
 .../runners/spark/io/SparkUnboundedSource.java  |  19 +-
 .../spark/metrics/MetricsAccumulator.java       |  20 +-
 .../spark/metrics/MetricsAccumulatorParam.java  |  20 +-
 .../runners/spark/metrics/SparkBeamMetric.java  |  11 +-
 .../spark/metrics/SparkBeamMetricSource.java    |   2 +-
 .../spark/metrics/SparkMetricResults.java       | 172 -------
 .../spark/metrics/SparkMetricsContainer.java    | 174 -------
 .../SparkGroupAlsoByWindowViaWindowSet.java     |   4 +-
 .../spark/stateful/StateSpecFunctions.java      |   8 +-
 .../translation/DoFnRunnerWithMetrics.java      |   6 +-
 .../spark/translation/MultiDoFnFunction.java    |   6 +-
 .../spark/translation/TransformTranslator.java  |   4 +-
 .../streaming/StreamingTransformTranslator.java |   4 +-
 .../apache/beam/sdk/metrics/CounterCell.java    |  27 +-
 .../org/apache/beam/sdk/metrics/DirtyState.java |   3 +-
 .../beam/sdk/metrics/DistributionCell.java      |  16 +-
 .../org/apache/beam/sdk/metrics/GaugeCell.java  |  20 +-
 .../org/apache/beam/sdk/metrics/MetricCell.java |  14 +-
 .../org/apache/beam/sdk/metrics/Metrics.java    |   2 +-
 .../beam/sdk/metrics/MetricsContainer.java      |  29 +-
 .../sdk/metrics/MetricsContainerStepMap.java    | 487 +++++++++++++++++++
 .../org/apache/beam/sdk/metrics/MetricsMap.java |   5 +-
 .../beam/sdk/metrics/CounterCellTest.java       |   6 +-
 .../metrics/MetricsContainerStepMapTest.java    | 258 ++++++++++
 .../beam/sdk/metrics/MetricsContainerTest.java  |  14 +-
 37 files changed, 1086 insertions(+), 790 deletions(-)
----------------------------------------------------------------------
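
The changes below replace the runner-specific FlinkMetricResults and
SparkMetricsContainer/SparkMetricResults classes with a single shared
MetricsContainerStepMap in the SDK: each runner accumulates one map of
per-step metrics containers and converts it to MetricResults through
MetricsContainerStepMap.asAttemptedOnlyMetricResults. A minimal sketch of
querying the resulting metrics from a PipelineResult after this change
(the namespace and counter name are placeholders, not part of this diff):

    import org.apache.beam.sdk.PipelineResult;
    import org.apache.beam.sdk.metrics.MetricNameFilter;
    import org.apache.beam.sdk.metrics.MetricQueryResults;
    import org.apache.beam.sdk.metrics.MetricResult;
    import org.apache.beam.sdk.metrics.MetricsFilter;

    // Query metrics from a finished pipeline; these runners only support
    // attempted() values (committed() throws UnsupportedOperationException).
    MetricQueryResults results = pipelineResult.metrics().queryMetrics(
        MetricsFilter.builder()
            .addNameFilter(MetricNameFilter.named("myNamespace", "myCounter"))
            .build());
    for (MetricResult<Long> counter : results.counters()) {
      System.out.println(counter.name() + ": " + counter.attempted());
    }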


http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/core-java/src/main/java/org/apache/beam/runners/core/LateDataUtils.java
----------------------------------------------------------------------
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/LateDataUtils.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/LateDataUtils.java
index c45387b..f7c0d31 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/LateDataUtils.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/LateDataUtils.java
@@ -71,7 +71,7 @@ public class LateDataUtils {
                         .isBefore(timerInternals.currentInputWatermarkTime());
                 if (expired) {
                   // The element is too late for this window.
-                  droppedDueToLateness.inc();
+                  droppedDueToLateness.update(1L);
                   WindowTracing.debug(
                       "GroupAlsoByWindow: Dropping element at {} for key: {}; "
                           + "window: {} since it is too far behind inputWatermark: {}",

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java
index 90dc79b..038895a 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java
@@ -17,12 +17,15 @@
  */
 package org.apache.beam.runners.flink;
 
+import static org.apache.beam.sdk.metrics.MetricsContainerStepMap.asAttemptedOnlyMetricResults;
+
 import java.io.IOException;
 import java.util.Collections;
 import java.util.Map;
-import org.apache.beam.runners.flink.metrics.FlinkMetricResults;
+import org.apache.beam.runners.flink.metrics.FlinkMetricContainer;
 import org.apache.beam.sdk.PipelineResult;
 import org.apache.beam.sdk.metrics.MetricResults;
+import org.apache.beam.sdk.metrics.MetricsContainerStepMap;
 import org.joda.time.Duration;
 
 /**
@@ -72,6 +75,7 @@ public class FlinkRunnerResult implements PipelineResult {
 
   @Override
   public MetricResults metrics() {
-    return new FlinkMetricResults(accumulators);
+    return asAttemptedOnlyMetricResults(
+        (MetricsContainerStepMap) accumulators.get(FlinkMetricContainer.ACCUMULATOR_NAME));
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/DoFnRunnerWithMetricsUpdate.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/DoFnRunnerWithMetricsUpdate.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/DoFnRunnerWithMetricsUpdate.java
index dae91fe..40191d2 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/DoFnRunnerWithMetricsUpdate.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/DoFnRunnerWithMetricsUpdate.java
@@ -34,6 +34,7 @@ import org.joda.time.Instant;
  */
 public class DoFnRunnerWithMetricsUpdate<InputT, OutputT> implements DoFnRunner<InputT, OutputT> {
 
+  private final String stepName;
   private final FlinkMetricContainer container;
   private final DoFnRunner<InputT, OutputT> delegate;
 
@@ -41,14 +42,15 @@ public class DoFnRunnerWithMetricsUpdate<InputT, OutputT> implements DoFnRunner<
       String stepName,
       DoFnRunner<InputT, OutputT> delegate,
       RuntimeContext runtimeContext) {
+    this.stepName = stepName;
     this.delegate = delegate;
-    container = new FlinkMetricContainer(stepName, runtimeContext);
+    container = new FlinkMetricContainer(runtimeContext);
   }
 
   @Override
   public void startBundle() {
     try (Closeable ignored =
-             MetricsEnvironment.scopedMetricsContainer(container.getMetricsContainer())) {
+             MetricsEnvironment.scopedMetricsContainer(container.getMetricsContainer(stepName))) {
       delegate.startBundle();
     } catch (IOException e) {
       throw new RuntimeException(e);
@@ -58,7 +60,7 @@ public class DoFnRunnerWithMetricsUpdate<InputT, OutputT> implements DoFnRunner<
   @Override
   public void processElement(final WindowedValue<InputT> elem) {
     try (Closeable ignored =
-             MetricsEnvironment.scopedMetricsContainer(container.getMetricsContainer())) {
+             MetricsEnvironment.scopedMetricsContainer(container.getMetricsContainer(stepName))) {
       delegate.processElement(elem);
     } catch (IOException e) {
       throw new RuntimeException(e);
@@ -69,7 +71,7 @@ public class DoFnRunnerWithMetricsUpdate<InputT, OutputT> implements DoFnRunner<
   public void onTimer(final String timerId, final BoundedWindow window, final Instant timestamp,
                       final TimeDomain timeDomain) {
     try (Closeable ignored =
-             MetricsEnvironment.scopedMetricsContainer(container.getMetricsContainer())) {
+             MetricsEnvironment.scopedMetricsContainer(container.getMetricsContainer(stepName))) {
       delegate.onTimer(timerId, window, timestamp, timeDomain);
     } catch (IOException e) {
       throw new RuntimeException(e);
@@ -79,7 +81,7 @@ public class DoFnRunnerWithMetricsUpdate<InputT, OutputT> implements DoFnRunner<
   @Override
   public void finishBundle() {
     try (Closeable ignored =
-             MetricsEnvironment.scopedMetricsContainer(container.getMetricsContainer())) {
+             MetricsEnvironment.scopedMetricsContainer(container.getMetricsContainer(stepName))) {
       delegate.finishBundle();
     } catch (IOException e) {
       throw new RuntimeException(e);
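
Each DoFnRunner method above wraps its delegate call in the same idiom:
MetricsEnvironment.scopedMetricsContainer installs the per-step container
for the duration of the call, so metrics reported from user code are
attributed to that step. A minimal sketch of the idiom in isolation,
assuming a FlinkMetricContainer named container as above, the
java.io.Closeable/IOException imports, and placeholder step and counter
names:

    try (Closeable ignored =
             MetricsEnvironment.scopedMetricsContainer(
                 container.getMetricsContainer("myStep"))) {
      // Any metric reported while this scope is open lands in the
      // "myStep" container inside the shared MetricsContainerStepMap.
      Metrics.counter("myNamespace", "elementsSeen").inc();
    } catch (IOException e) {
      throw new RuntimeException(e);
    }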

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/FlinkMetricContainer.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/FlinkMetricContainer.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/FlinkMetricContainer.java
index d020f69..f81205e 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/FlinkMetricContainer.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/FlinkMetricContainer.java
@@ -17,19 +17,24 @@
  */
 package org.apache.beam.runners.flink.metrics;
 
+import static org.apache.beam.sdk.metrics.MetricsContainerStepMap.asAttemptedOnlyMetricResults;
+
 import java.util.HashMap;
 import java.util.Map;
-import org.apache.beam.sdk.metrics.DistributionData;
-import org.apache.beam.sdk.metrics.GaugeData;
-import org.apache.beam.sdk.metrics.MetricKey;
-import org.apache.beam.sdk.metrics.MetricName;
-import org.apache.beam.sdk.metrics.MetricUpdates;
+import org.apache.beam.sdk.metrics.DistributionResult;
+import org.apache.beam.sdk.metrics.GaugeResult;
+import org.apache.beam.sdk.metrics.MetricQueryResults;
+import org.apache.beam.sdk.metrics.MetricResult;
+import org.apache.beam.sdk.metrics.MetricResults;
 import org.apache.beam.sdk.metrics.MetricsContainer;
+import org.apache.beam.sdk.metrics.MetricsContainerStepMap;
+import org.apache.beam.sdk.metrics.MetricsFilter;
 import org.apache.flink.api.common.accumulators.Accumulator;
-import org.apache.flink.api.common.accumulators.LongCounter;
 import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.metrics.Counter;
 import org.apache.flink.metrics.Gauge;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * Helper class for holding a {@link MetricsContainer} and forwarding Beam metrics to
@@ -37,46 +42,61 @@ import org.apache.flink.metrics.Gauge;
  */
 public class FlinkMetricContainer {
 
+  public static final String ACCUMULATOR_NAME = "__metricscontainers";
+
+  private static final Logger LOG = LoggerFactory.getLogger(FlinkMetricContainer.class);
+
   private static final String METRIC_KEY_SEPARATOR = "__";
-  static final String COUNTER_PREFIX = "__counter";
-  static final String DISTRIBUTION_PREFIX = "__distribution";
-  static final String GAUGE_PREFIX = "__gauge";
+  private static final String COUNTER_PREFIX = "__counter";
+  private static final String DISTRIBUTION_PREFIX = "__distribution";
+  private static final String GAUGE_PREFIX = "__gauge";
 
-  private final MetricsContainer metricsContainer;
   private final RuntimeContext runtimeContext;
   private final Map<String, Counter> flinkCounterCache;
   private final Map<String, FlinkDistributionGauge> flinkDistributionGaugeCache;
   private final Map<String, FlinkGauge> flinkGaugeCache;
+  private final MetricsAccumulator metricsAccumulator;
 
-  public FlinkMetricContainer(String stepName, RuntimeContext runtimeContext) {
-    metricsContainer = new MetricsContainer(stepName);
+  public FlinkMetricContainer(RuntimeContext runtimeContext) {
     this.runtimeContext = runtimeContext;
-    flinkCounterCache = new HashMap<>();
-    flinkDistributionGaugeCache = new HashMap<>();
-    flinkGaugeCache = new HashMap<>();
+    this.flinkCounterCache = new HashMap<>();
+    this.flinkDistributionGaugeCache = new HashMap<>();
+    this.flinkGaugeCache = new HashMap<>();
+
+    Accumulator<MetricsContainerStepMap, MetricsContainerStepMap> metricsAccumulator =
+        runtimeContext.getAccumulator(ACCUMULATOR_NAME);
+    if (metricsAccumulator == null) {
+      metricsAccumulator = new MetricsAccumulator();
+      try {
+        runtimeContext.addAccumulator(ACCUMULATOR_NAME, metricsAccumulator);
+      } catch (Exception e) {
+        LOG.error("Failed to create metrics accumulator.", e);
+      }
+    }
+    this.metricsAccumulator = (MetricsAccumulator) metricsAccumulator;
   }
 
-  public MetricsContainer getMetricsContainer() {
-    return metricsContainer;
+  MetricsContainer getMetricsContainer(String stepName) {
+    return metricsAccumulator != null
+        ? metricsAccumulator.getLocalValue().getContainer(stepName)
+        : null;
   }
 
-  public void updateMetrics() {
-    // update metrics
-    MetricUpdates updates = metricsContainer.getUpdates();
-    if (updates != null) {
-      updateCounters(updates.counterUpdates());
-      updateDistributions(updates.distributionUpdates());
-      updateGauge(updates.gaugeUpdates());
-      metricsContainer.commitUpdates();
-    }
+  void updateMetrics() {
+    MetricResults metricResults =
+        asAttemptedOnlyMetricResults(metricsAccumulator.getLocalValue());
+    MetricQueryResults metricQueryResults =
+        metricResults.queryMetrics(MetricsFilter.builder().build());
+    updateCounters(metricQueryResults.counters());
+    updateDistributions(metricQueryResults.distributions());
+    updateGauge(metricQueryResults.gauges());
   }
 
-  private void updateCounters(Iterable<MetricUpdates.MetricUpdate<Long>> updates) {
-
-    for (MetricUpdates.MetricUpdate<Long> metricUpdate : updates) {
+  private void updateCounters(Iterable<MetricResult<Long>> counters) {
+    for (MetricResult<Long> metricResult : counters) {
+      String flinkMetricName = getFlinkMetricNameString(COUNTER_PREFIX, metricResult);
 
-      String flinkMetricName = getFlinkMetricNameString(COUNTER_PREFIX, metricUpdate.getKey());
-      Long update = metricUpdate.getUpdate();
+      Long update = metricResult.attempted();
 
       // update flink metric
       Counter counter = flinkCounterCache.get(flinkMetricName);
@@ -86,26 +106,15 @@ public class FlinkMetricContainer {
       }
       counter.dec(counter.getCount());
       counter.inc(update);
-
-      // update flink accumulator
-      Accumulator<Long, Long> accumulator = runtimeContext.getAccumulator(flinkMetricName);
-      if (accumulator == null) {
-        accumulator = new LongCounter(update);
-        runtimeContext.addAccumulator(flinkMetricName, accumulator);
-      } else {
-        accumulator.resetLocal();
-        accumulator.add(update);
-      }
     }
   }
 
-  private void updateDistributions(Iterable<MetricUpdates.MetricUpdate<DistributionData>> updates) {
-
-    for (MetricUpdates.MetricUpdate<DistributionData> metricUpdate : updates) {
-
+  private void updateDistributions(Iterable<MetricResult<DistributionResult>> distributions) {
+    for (MetricResult<DistributionResult> metricResult : distributions) {
       String flinkMetricName =
-          getFlinkMetricNameString(DISTRIBUTION_PREFIX, metricUpdate.getKey());
-      DistributionData update = metricUpdate.getUpdate();
+          getFlinkMetricNameString(DISTRIBUTION_PREFIX, metricResult);
+
+      DistributionResult update = metricResult.attempted();
 
       // update flink metric
       FlinkDistributionGauge gauge = flinkDistributionGaugeCache.get(flinkMetricName);
@@ -116,26 +125,15 @@ public class FlinkMetricContainer {
       } else {
         gauge.update(update);
       }
-
-      // update flink accumulator
-      Accumulator<DistributionData, DistributionData> accumulator =
-          runtimeContext.getAccumulator(flinkMetricName);
-      if (accumulator == null) {
-        accumulator = new FlinkDistributionDataAccumulator(update);
-        runtimeContext.addAccumulator(flinkMetricName, accumulator);
-      } else {
-        accumulator.resetLocal();
-        accumulator.add(update);
-      }
     }
   }
 
-  private void updateGauge(Iterable<MetricUpdates.MetricUpdate<GaugeData>> updates) {
-    for (MetricUpdates.MetricUpdate<GaugeData> metricUpdate : updates) {
-
+  private void updateGauge(Iterable<MetricResult<GaugeResult>> gauges) {
+    for (MetricResult<GaugeResult> metricResult : gauges) {
       String flinkMetricName =
-          getFlinkMetricNameString(GAUGE_PREFIX, metricUpdate.getKey());
-      GaugeData update = metricUpdate.getUpdate();
+          getFlinkMetricNameString(GAUGE_PREFIX, metricResult);
+
+      GaugeResult update = metricResult.attempted();
 
       // update flink metric
       FlinkGauge gauge = flinkGaugeCache.get(flinkMetricName);
@@ -146,170 +144,55 @@ public class FlinkMetricContainer {
       } else {
         gauge.update(update);
       }
-
-      // update flink accumulator
-      Accumulator<GaugeData, GaugeData> accumulator =
-          runtimeContext.getAccumulator(flinkMetricName);
-      if (accumulator == null) {
-        accumulator = new FlinkGaugeAccumulator(update);
-        runtimeContext.addAccumulator(flinkMetricName, accumulator);
-      }
-      accumulator.resetLocal();
-      accumulator.add(update);
     }
   }
 
-  private static String getFlinkMetricNameString(String prefix, MetricKey key) {
+  private static String getFlinkMetricNameString(String prefix, MetricResult<?> metricResult) {
     return prefix
-        + METRIC_KEY_SEPARATOR + key.stepName()
-        + METRIC_KEY_SEPARATOR + key.metricName().namespace()
-        + METRIC_KEY_SEPARATOR + key.metricName().name();
-  }
-
-  static MetricKey parseMetricKey(String flinkMetricName) {
-    String[] arr = flinkMetricName.split(METRIC_KEY_SEPARATOR);
-    return MetricKey.create(arr[2], MetricName.named(arr[3], arr[4]));
+        + METRIC_KEY_SEPARATOR + metricResult.step()
+        + METRIC_KEY_SEPARATOR + metricResult.name().namespace()
+        + METRIC_KEY_SEPARATOR + metricResult.name().name();
   }
 
   /**
-   * Flink {@link Gauge} for {@link DistributionData}.
+   * Flink {@link Gauge} for {@link DistributionResult}.
    */
-  public static class FlinkDistributionGauge implements Gauge<DistributionData> {
+  public static class FlinkDistributionGauge implements Gauge<DistributionResult> {
 
-    DistributionData data;
+    DistributionResult data;
 
-    FlinkDistributionGauge(DistributionData data) {
+    FlinkDistributionGauge(DistributionResult data) {
       this.data = data;
     }
 
-    void update(DistributionData data) {
+    void update(DistributionResult data) {
       this.data = data;
     }
 
     @Override
-    public DistributionData getValue() {
+    public DistributionResult getValue() {
       return data;
     }
   }
 
   /**
-   * Flink {@link Gauge} for {@link GaugeData}.
+   * Flink {@link Gauge} for {@link GaugeResult}.
    */
-  public static class FlinkGauge implements Gauge<GaugeData> {
+  public static class FlinkGauge implements Gauge<GaugeResult> {
 
-    GaugeData data;
+    GaugeResult data;
 
-    FlinkGauge(GaugeData data) {
+    FlinkGauge(GaugeResult data) {
       this.data = data;
     }
 
-    void update(GaugeData update) {
-      this.data = data.combine(update);
+    void update(GaugeResult update) {
+      this.data = update;
     }
 
     @Override
-    public GaugeData getValue() {
+    public GaugeResult getValue() {
       return data;
     }
   }
-
-  /**
-   * Flink {@link Accumulator} for {@link GaugeData}.
-   */
-  public static class FlinkDistributionDataAccumulator implements
-      Accumulator<DistributionData, DistributionData> {
-
-    private static final long serialVersionUID = 1L;
-
-    private DistributionData data;
-
-    public FlinkDistributionDataAccumulator(DistributionData data) {
-      this.data = data;
-    }
-
-    @Override
-    public void add(DistributionData value) {
-      if (data == null) {
-        this.data = value;
-      } else {
-        this.data = this.data.combine(value);
-      }
-    }
-
-    @Override
-    public DistributionData getLocalValue() {
-      return data;
-    }
-
-    @Override
-    public void resetLocal() {
-      data = null;
-    }
-
-    @Override
-    public void merge(Accumulator<DistributionData, DistributionData> other) {
-      data = data.combine(other.getLocalValue());
-    }
-
-    @Override
-    public Accumulator<DistributionData, DistributionData> clone() {
-      try {
-        super.clone();
-      } catch (CloneNotSupportedException e) {
-        throw new RuntimeException(e);
-      }
-
-      return new FlinkDistributionDataAccumulator(
-          DistributionData.create(data.sum(), data.count(), data.min(), data.max()));
-    }
-  }
-
-  /**
-   * Flink {@link Accumulator} for {@link GaugeData}.
-   */
-  public static class FlinkGaugeAccumulator implements Accumulator<GaugeData, GaugeData> {
-
-    private GaugeData data;
-
-    public FlinkGaugeAccumulator(GaugeData data) {
-      this.data = data;
-    }
-
-    @Override
-    public void add(GaugeData value) {
-      if (data == null) {
-        this.data = value;
-      } else {
-        this.data = this.data.combine(value);
-      }
-    }
-
-    @Override
-    public GaugeData getLocalValue() {
-      return data;
-    }
-
-    @Override
-    public void resetLocal() {
-      this.data = null;
-    }
-
-    @Override
-    public void merge(Accumulator<GaugeData, GaugeData> other) {
-      data = data.combine(other.getLocalValue());
-    }
-
-    @Override
-    public Accumulator<GaugeData, GaugeData> clone() {
-      try {
-        super.clone();
-      } catch (CloneNotSupportedException e) {
-        throw new RuntimeException(e);
-      }
-
-      return new FlinkGaugeAccumulator(
-          GaugeData.create(data.value()));
-    }
-  }
-
 }
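
One detail worth noting in updateCounters() above: Beam metric results are
cumulative, while Flink's Counter exposes only relative inc/dec operations,
so the container zeroes the counter before re-applying the absolute value.
A sketch of that idiom, assuming counter is an
org.apache.flink.metrics.Counter and update is the cumulative attempted
value from the Beam MetricResult:

    counter.dec(counter.getCount());  // reset the Flink counter to zero
    counter.inc(update);              // apply the absolute attempted value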

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/FlinkMetricResults.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/FlinkMetricResults.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/FlinkMetricResults.java
deleted file mode 100644
index 9e1430b..0000000
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/FlinkMetricResults.java
+++ /dev/null
@@ -1,146 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.flink.metrics;
-
-
-import static org.apache.beam.runners.flink.metrics.FlinkMetricContainer.COUNTER_PREFIX;
-import static org.apache.beam.runners.flink.metrics.FlinkMetricContainer.DISTRIBUTION_PREFIX;
-import static org.apache.beam.runners.flink.metrics.FlinkMetricContainer.GAUGE_PREFIX;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import org.apache.beam.sdk.metrics.DistributionData;
-import org.apache.beam.sdk.metrics.DistributionResult;
-import org.apache.beam.sdk.metrics.GaugeData;
-import org.apache.beam.sdk.metrics.GaugeResult;
-import org.apache.beam.sdk.metrics.MetricFiltering;
-import org.apache.beam.sdk.metrics.MetricKey;
-import org.apache.beam.sdk.metrics.MetricName;
-import org.apache.beam.sdk.metrics.MetricQueryResults;
-import org.apache.beam.sdk.metrics.MetricResult;
-import org.apache.beam.sdk.metrics.MetricResults;
-import org.apache.beam.sdk.metrics.MetricsFilter;
-
-/**
- * Implementation of {@link MetricResults} for the Flink Runner.
- */
-public class FlinkMetricResults extends MetricResults {
-
-  private Map<String, Object> accumulators;
-
-  public FlinkMetricResults(Map<String, Object> accumulators) {
-    this.accumulators = accumulators;
-  }
-
-  @Override
-  public MetricQueryResults queryMetrics(MetricsFilter filter) {
-    return new FlinkMetricQueryResults(filter);
-  }
-
-  private class FlinkMetricQueryResults implements MetricQueryResults {
-
-    private MetricsFilter filter;
-
-    FlinkMetricQueryResults(MetricsFilter filter) {
-      this.filter = filter;
-    }
-
-    @Override
-    public Iterable<MetricResult<Long>> counters() {
-      List<MetricResult<Long>> result = new ArrayList<>();
-      for (Map.Entry<String, Object> accumulator : accumulators.entrySet()) {
-        if (accumulator.getKey().startsWith(COUNTER_PREFIX)) {
-          MetricKey metricKey = FlinkMetricContainer.parseMetricKey(accumulator.getKey());
-          if (MetricFiltering.matches(filter, metricKey)) {
-            result.add(new FlinkMetricResult<>(
-                metricKey.metricName(), metricKey.stepName(), (Long) accumulator.getValue()));
-          }
-        }
-      }
-      return result;
-    }
-
-    @Override
-    public Iterable<MetricResult<DistributionResult>> distributions() {
-      List<MetricResult<DistributionResult>> result = new ArrayList<>();
-      for (Map.Entry<String, Object> accumulator : accumulators.entrySet()) {
-        if (accumulator.getKey().startsWith(DISTRIBUTION_PREFIX)) {
-          MetricKey metricKey = FlinkMetricContainer.parseMetricKey(accumulator.getKey());
-          DistributionData data = (DistributionData) accumulator.getValue();
-          if (MetricFiltering.matches(filter, metricKey)) {
-            result.add(new FlinkMetricResult<>(
-                metricKey.metricName(), metricKey.stepName(), data.extractResult()));
-          }
-        }
-      }
-      return result;
-    }
-
-    @Override
-    public Iterable<MetricResult<GaugeResult>> gauges() {
-      List<MetricResult<GaugeResult>> result = new ArrayList<>();
-      for (Map.Entry<String, Object> accumulator : accumulators.entrySet()) {
-        if (accumulator.getKey().startsWith(GAUGE_PREFIX)) {
-          MetricKey metricKey = FlinkMetricContainer.parseMetricKey(accumulator.getKey());
-          GaugeData data = (GaugeData) accumulator.getValue();
-          if (MetricFiltering.matches(filter, metricKey)) {
-            result.add(new FlinkMetricResult<>(
-                metricKey.metricName(), metricKey.stepName(), data.extractResult()));
-          }
-        }
-      }
-      return result;
-    }
-
-  }
-
-  private static class FlinkMetricResult<T> implements MetricResult<T> {
-    private final MetricName name;
-    private final String step;
-    private final T result;
-
-    FlinkMetricResult(MetricName name, String step, T result) {
-      this.name = name;
-      this.step = step;
-      this.result = result;
-    }
-
-    @Override
-    public MetricName name() {
-      return name;
-    }
-
-    @Override
-    public String step() {
-      return step;
-    }
-
-    @Override
-    public T committed() {
-      throw new UnsupportedOperationException("Flink runner does not currently support committed"
-          + " metrics results. Please use 'attempted' instead.");
-    }
-
-    @Override
-    public T attempted() {
-      return result;
-    }
-  }
-
-}

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/MetricsAccumulator.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/MetricsAccumulator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/MetricsAccumulator.java
new file mode 100644
index 0000000..a9dc2ce
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/MetricsAccumulator.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.metrics;
+
+import org.apache.beam.sdk.metrics.MetricsContainerStepMap;
+import org.apache.flink.api.common.accumulators.Accumulator;
+import org.apache.flink.api.common.accumulators.SimpleAccumulator;
+
+/**
+ * Accumulator of {@link MetricsContainerStepMap}.
+ */
+public class MetricsAccumulator implements SimpleAccumulator<MetricsContainerStepMap> {
+  private MetricsContainerStepMap metricsContainers = new MetricsContainerStepMap();
+
+  @Override
+  public void add(MetricsContainerStepMap value) {
+    metricsContainers.updateAll(value);
+  }
+
+  @Override
+  public MetricsContainerStepMap getLocalValue() {
+    return metricsContainers;
+  }
+
+  @Override
+  public void resetLocal() {
+    this.metricsContainers = new MetricsContainerStepMap();
+  }
+
+  @Override
+  public void merge(Accumulator<MetricsContainerStepMap, MetricsContainerStepMap> other) {
+    this.add(other.getLocalValue());
+  }
+
+  @Override
+  public Accumulator<MetricsContainerStepMap, MetricsContainerStepMap> clone() {
+    try {
+      super.clone();
+    } catch (CloneNotSupportedException ignored) {
+    }
+    MetricsAccumulator metricsAccumulator = new MetricsAccumulator();
+    metricsAccumulator.getLocalValue().updateAll(this.getLocalValue());
+    return metricsAccumulator;
+  }
+}
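
FlinkMetricContainer registers one instance of this accumulator per
operator under ACCUMULATOR_NAME; Flink then merges the per-subtask values
into the single MetricsContainerStepMap that FlinkRunnerResult.metrics()
reads. A minimal sketch of the get-or-create registration, assuming a
Flink RuntimeContext named runtimeContext and a placeholder step name:

    Accumulator<MetricsContainerStepMap, MetricsContainerStepMap> acc =
        runtimeContext.getAccumulator(FlinkMetricContainer.ACCUMULATOR_NAME);
    if (acc == null) {
      acc = new MetricsAccumulator();
      runtimeContext.addAccumulator(FlinkMetricContainer.ACCUMULATOR_NAME, acc);
    }
    // Per-step container used to report metrics from this subtask:
    MetricsContainer container = acc.getLocalValue().getContainer("myStep");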

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/ReaderInvocationUtil.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/ReaderInvocationUtil.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/ReaderInvocationUtil.java
index 38263d9..64738cc 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/ReaderInvocationUtil.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/ReaderInvocationUtil.java
@@ -32,13 +32,16 @@ import org.apache.beam.sdk.options.PipelineOptions;
  */
 public class ReaderInvocationUtil<OutputT, ReaderT extends Source.Reader<OutputT>> {
 
+  private final String stepName;
   private final FlinkMetricContainer container;
   private final Boolean enableMetrics;
 
   public ReaderInvocationUtil(
+      String stepName,
       PipelineOptions options,
       FlinkMetricContainer container) {
     FlinkPipelineOptions flinkPipelineOptions = options.as(FlinkPipelineOptions.class);
+    this.stepName = stepName;
     enableMetrics = flinkPipelineOptions.getEnableMetrics();
     this.container = container;
   }
@@ -46,7 +49,7 @@ public class ReaderInvocationUtil<OutputT, ReaderT extends Source.Reader<OutputT
   public boolean invokeStart(ReaderT reader) throws IOException {
     if (enableMetrics) {
       try (Closeable ignored =
-               MetricsEnvironment.scopedMetricsContainer(container.getMetricsContainer())) {
+               MetricsEnvironment.scopedMetricsContainer(container.getMetricsContainer(stepName))) {
         boolean result = reader.start();
         container.updateMetrics();
         return result;
@@ -59,7 +62,7 @@ public class ReaderInvocationUtil<OutputT, ReaderT extends Source.Reader<OutputT
   public boolean invokeAdvance(ReaderT reader) throws IOException {
     if (enableMetrics) {
       try (Closeable ignored =
-               MetricsEnvironment.scopedMetricsContainer(container.getMetricsContainer())) {
+               MetricsEnvironment.scopedMetricsContainer(container.getMetricsContainer(stepName))) {
         boolean result = reader.advance();
         container.updateMetrics();
         return result;

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java
index f2b81fc..27e6912 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java
@@ -71,9 +71,13 @@ public class SourceInputFormat<T>
 
   @Override
   public void open(SourceInputSplit<T> sourceInputSplit) throws IOException {
-    FlinkMetricContainer metricContainer = new FlinkMetricContainer(stepName, getRuntimeContext());
+    FlinkMetricContainer metricContainer = new FlinkMetricContainer(getRuntimeContext());
+
     readerInvoker =
-        new ReaderInvocationUtil<>(serializedOptions.getPipelineOptions(), metricContainer);
+        new ReaderInvocationUtil<>(
+            stepName,
+            serializedOptions.getPipelineOptions(),
+            metricContainer);
 
     reader = ((BoundedSource<T>) sourceInputSplit.getSource()).createReader(options);
     inputAvailable = readerInvoker.invokeStart(reader);

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/BoundedSourceWrapper.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/BoundedSourceWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/BoundedSourceWrapper.java
index a142685..6d75688 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/BoundedSourceWrapper.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/BoundedSourceWrapper.java
@@ -104,9 +104,13 @@ public class BoundedSourceWrapper<OutputT>
         numSubtasks,
         localSources);
 
-    FlinkMetricContainer metricContainer = new FlinkMetricContainer(stepName, getRuntimeContext());
+    FlinkMetricContainer metricContainer = new FlinkMetricContainer(getRuntimeContext());
+
     ReaderInvocationUtil<OutputT, BoundedSource.BoundedReader<OutputT>> readerInvoker =
-        new ReaderInvocationUtil<>(serializedOptions.getPipelineOptions(), metricContainer);
+        new ReaderInvocationUtil<>(
+            stepName,
+            serializedOptions.getPipelineOptions(),
+            metricContainer);
 
     readers = new ArrayList<>();
     // initialize readers from scratch

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java
index a731e2b..ec21699 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java
@@ -215,10 +215,13 @@ public class UnboundedSourceWrapper<
 
     context = ctx;
 
-    FlinkMetricContainer metricContainer = new FlinkMetricContainer(stepName, getRuntimeContext());
-    ReaderInvocationUtil<OutputT, UnboundedSource.UnboundedReader<OutputT>> readerInvoker =
-        new ReaderInvocationUtil<>(serializedOptions.getPipelineOptions(), metricContainer);
+    FlinkMetricContainer metricContainer = new FlinkMetricContainer(getRuntimeContext());
 
+    ReaderInvocationUtil<OutputT, UnboundedSource.UnboundedReader<OutputT>> readerInvoker =
+        new ReaderInvocationUtil<>(
+            stepName,
+            serializedOptions.getPipelineOptions(),
+            metricContainer);
 
     if (localReaders.size() == 0) {
       // do nothing, but still look busy ...

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java
index 3e94a45..3986e33 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java
@@ -18,13 +18,15 @@
 
 package org.apache.beam.runners.spark;
 
+import static org.apache.beam.sdk.metrics.MetricsContainerStepMap.asAttemptedOnlyMetricResults;
+
 import java.io.IOException;
 import java.util.Objects;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
-import org.apache.beam.runners.spark.metrics.SparkMetricResults;
+import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
 import org.apache.beam.runners.spark.translation.SparkContextFactory;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.PipelineResult;
@@ -41,7 +43,6 @@ public abstract class SparkPipelineResult implements PipelineResult {
   protected final Future pipelineExecution;
   protected JavaSparkContext javaSparkContext;
   protected PipelineResult.State state;
-  private final SparkMetricResults metricResults = new SparkMetricResults();
 
   SparkPipelineResult(final Future<?> pipelineExecution, final JavaSparkContext javaSparkContext) {
     this.pipelineExecution = pipelineExecution;
@@ -106,7 +107,8 @@ public abstract class SparkPipelineResult implements PipelineResult {
 
   @Override
   public MetricResults metrics() {
-    return metricResults;
+    return asAttemptedOnlyMetricResults(
+        MetricsAccumulator.getInstance().value());
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
index e294359..71a19e7 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
@@ -26,12 +26,12 @@ import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
-import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
 import org.apache.beam.sdk.io.BoundedSource;
 import org.apache.beam.sdk.io.Source;
 import org.apache.beam.sdk.io.UnboundedSource;
 import org.apache.beam.sdk.metrics.MetricsContainer;
+import org.apache.beam.sdk.metrics.MetricsContainerStepMap;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.spark.Accumulator;
@@ -65,7 +65,7 @@ public class SourceRDD {
     private final SparkRuntimeContext runtimeContext;
     private final int numPartitions;
     private final String stepName;
-    private final Accumulator<SparkMetricsContainer> metricsAccum;
+    private final Accumulator<MetricsContainerStepMap> metricsAccum;
 
     // to satisfy Scala API.
     private static final scala.collection.immutable.Seq<Dependency<?>> NIL =

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
index 0388f6c..2a9de4b 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
@@ -25,7 +25,6 @@ import java.util.Collections;
 import org.apache.beam.runners.spark.SparkPipelineOptions;
 import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
-import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.runners.spark.stateful.StateSpecFunctions;
 import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
 import org.apache.beam.runners.spark.translation.streaming.UnboundedDataset;
@@ -37,6 +36,7 @@ import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark;
 import org.apache.beam.sdk.metrics.Gauge;
 import org.apache.beam.sdk.metrics.Metrics;
 import org.apache.beam.sdk.metrics.MetricsContainer;
+import org.apache.beam.sdk.metrics.MetricsContainerStepMap;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
@@ -190,7 +190,7 @@ public class SparkUnboundedSource {
     public scala.Option<RDD<BoxedUnit>> compute(Time validTime) {
       // compute parent.
       scala.Option<RDD<Metadata>> parentRDDOpt = parent.getOrCompute(validTime);
-      final Accumulator<SparkMetricsContainer> metricsAccum = MetricsAccumulator.getInstance();
+      final Accumulator<MetricsContainerStepMap> metricsAccum = MetricsAccumulator.getInstance();
       long count = 0;
       SparkWatermarks sparkWatermark = null;
       Instant globalLowWatermarkForBatch = BoundedWindow.TIMESTAMP_MIN_VALUE;
@@ -211,7 +211,7 @@ public class SparkUnboundedSource {
                   ? partitionHighWatermark : globalHighWatermarkForBatch;
           // Update metrics reported in the read
           final Gauge gauge = Metrics.gauge(NAMESPACE, READ_DURATION_MILLIS);
-          final MetricsContainer container = metadata.getMetricsContainer().getContainer(stepName);
+          final MetricsContainer container = metadata.getMetricsContainers().getContainer(stepName);
           try (Closeable ignored = MetricsEnvironment.scopedMetricsContainer(container)) {
             final long readDurationMillis = metadata.getReadDurationMillis();
             if (readDurationMillis > maxReadDuration) {
@@ -220,7 +220,7 @@ public class SparkUnboundedSource {
           } catch (IOException e) {
             throw new RuntimeException(e);
           }
-          metricsAccum.value().update(metadata.getMetricsContainer());
+          metricsAccum.value().updateAll(metadata.getMetricsContainers());
         }
 
         sparkWatermark =
@@ -260,20 +260,19 @@ public class SparkUnboundedSource {
     private final Instant lowWatermark;
     private final Instant highWatermark;
     private final long readDurationMillis;
-    private final SparkMetricsContainer metricsContainer;
+    private final MetricsContainerStepMap metricsContainers;
 
     public Metadata(
         long numRecords,
         Instant lowWatermark,
         Instant highWatermark,
         final long readDurationMillis,
-        SparkMetricsContainer metricsContainer) {
+        MetricsContainerStepMap metricsContainer) {
       this.numRecords = numRecords;
       this.readDurationMillis = readDurationMillis;
-      this.metricsContainer = metricsContainer;
+      this.metricsContainers = metricsContainer;
       this.lowWatermark = lowWatermark;
       this.highWatermark = highWatermark;
-      metricsContainer.materialize();
     }
 
     long getNumRecords() {
@@ -292,8 +291,8 @@ public class SparkUnboundedSource {
       return readDurationMillis;
     }
 
-    SparkMetricsContainer getMetricsContainer() {
-      return metricsContainer;
+    MetricsContainerStepMap getMetricsContainers() {
+      return metricsContainers;
     }
   }
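
The compute() loop above reports read duration through a Beam Gauge while
an explicitly scoped container is installed, rather than relying on an
ambient container. A minimal sketch of that reporting idiom, with
placeholder namespace/name and assuming Gauge.set(long) from the Beam
metrics API plus the Closeable/IOException imports:

    final Gauge gauge = Metrics.gauge("myNamespace", "readDurationMillis");
    try (Closeable ignored = MetricsEnvironment.scopedMetricsContainer(container)) {
      gauge.set(readDurationMillis);  // recorded against this step's container
    } catch (IOException e) {
      throw new RuntimeException(e);
    }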
 

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulator.java
index 1153db6..1dcfa2f 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulator.java
@@ -24,6 +24,7 @@ import java.io.IOException;
 import org.apache.beam.runners.spark.SparkPipelineOptions;
 import org.apache.beam.runners.spark.translation.streaming.Checkpoint;
 import org.apache.beam.runners.spark.translation.streaming.Checkpoint.CheckpointDir;
+import org.apache.beam.sdk.metrics.MetricsContainerStepMap;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.spark.Accumulator;
@@ -44,7 +45,7 @@ public class MetricsAccumulator {
   private static final String ACCUMULATOR_NAME = "Beam.Metrics";
   private static final String ACCUMULATOR_CHECKPOINT_FILENAME = "metrics";
 
-  private static volatile Accumulator<SparkMetricsContainer> instance = null;
+  private static volatile Accumulator<MetricsContainerStepMap> instance = null;
   private static volatile FileSystem fileSystem;
   private static volatile Path checkpointFilePath;
 
@@ -58,11 +59,13 @@ public class MetricsAccumulator {
           Optional<CheckpointDir> maybeCheckpointDir =
               opts.isStreaming() ? Optional.of(new CheckpointDir(opts.getCheckpointDir()))
                   : Optional.<CheckpointDir>absent();
-          Accumulator<SparkMetricsContainer> accumulator =
-              jsc.sc().accumulator(new SparkMetricsContainer(), ACCUMULATOR_NAME,
+          Accumulator<MetricsContainerStepMap> accumulator =
+              jsc.sc().accumulator(
+                  new MetricsContainerStepMap(),
+                  ACCUMULATOR_NAME,
                   new MetricsAccumulatorParam());
           if (maybeCheckpointDir.isPresent()) {
-            Optional<SparkMetricsContainer> maybeRecoveredValue =
+            Optional<MetricsContainerStepMap> maybeRecoveredValue =
                 recoverValueFromCheckpoint(jsc, maybeCheckpointDir.get());
             if (maybeRecoveredValue.isPresent()) {
               accumulator.setValue(maybeRecoveredValue.get());
@@ -75,7 +78,7 @@ public class MetricsAccumulator {
     }
   }
 
-  public static Accumulator<SparkMetricsContainer> getInstance() {
+  public static Accumulator<MetricsContainerStepMap> getInstance() {
     if (instance == null) {
       throw new IllegalStateException("Metrics accumulator has not been instantiated");
     } else {
@@ -83,14 +86,15 @@ public class MetricsAccumulator {
     }
   }
 
-  private static Optional<SparkMetricsContainer> recoverValueFromCheckpoint(
+  private static Optional<MetricsContainerStepMap> recoverValueFromCheckpoint(
       JavaSparkContext jsc,
       CheckpointDir checkpointDir) {
     try {
       Path beamCheckpointPath = checkpointDir.getBeamCheckpointDir();
       checkpointFilePath = new Path(beamCheckpointPath, ACCUMULATOR_CHECKPOINT_FILENAME);
       fileSystem = checkpointFilePath.getFileSystem(jsc.hadoopConfiguration());
-      SparkMetricsContainer recoveredValue = Checkpoint.readObject(fileSystem, checkpointFilePath);
+      MetricsContainerStepMap recoveredValue =
+          Checkpoint.readObject(fileSystem, checkpointFilePath);
       if (recoveredValue != null) {
         LOG.info("Recovered metrics from checkpoint.");
         return Optional.of(recoveredValue);
@@ -117,7 +121,7 @@ public class MetricsAccumulator {
   }
 
   /**
-   * Spark Listener which checkpoints {@link SparkMetricsContainer} values for fault-tolerance.
+   * Spark Listener which checkpoints {@link MetricsContainerStepMap} values for fault-tolerance.
    */
   public static class AccumulatorCheckpointingSparkListener extends JavaStreamingListener {
     @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulatorParam.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulatorParam.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulatorParam.java
index 9948c81..dee4ebc 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulatorParam.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulatorParam.java
@@ -18,25 +18,31 @@
 
 package org.apache.beam.runners.spark.metrics;
 
+import org.apache.beam.sdk.metrics.MetricsContainerStepMap;
 import org.apache.spark.AccumulatorParam;
 
 
 /**
  * Metrics accumulator param.
  */
-class MetricsAccumulatorParam implements AccumulatorParam<SparkMetricsContainer> {
+class MetricsAccumulatorParam implements AccumulatorParam<MetricsContainerStepMap> {
   @Override
-  public SparkMetricsContainer addAccumulator(SparkMetricsContainer c1, SparkMetricsContainer c2) {
-    return c1.update(c2);
+  public MetricsContainerStepMap addAccumulator(
+      MetricsContainerStepMap c1,
+      MetricsContainerStepMap c2) {
+    return addInPlace(c1, c2);
   }
 
   @Override
-  public SparkMetricsContainer addInPlace(SparkMetricsContainer c1, SparkMetricsContainer c2) {
-    return c1.update(c2);
+  public MetricsContainerStepMap addInPlace(
+      MetricsContainerStepMap c1,
+      MetricsContainerStepMap c2) {
+    c1.updateAll(c2);
+    return c1;
   }
 
   @Override
-  public SparkMetricsContainer zero(SparkMetricsContainer initialValue) {
-    return new SparkMetricsContainer();
+  public MetricsContainerStepMap zero(MetricsContainerStepMap initialValue) {
+    return new MetricsContainerStepMap();
   }
 }
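
With this change the Spark runner accumulates the same shared type: the
MetricsAccumulator earlier in this diff creates the accumulator with this
param, and both merge paths reduce to MetricsContainerStepMap.updateAll.
A minimal sketch of creating and reading such an accumulator, assuming a
JavaSparkContext named jsc, code in this same package
(MetricsAccumulatorParam is package-private), and a static import of
asAttemptedOnlyMetricResults as in SparkPipelineResult above:

    // Mirrors the accumulator creation in MetricsAccumulator.init().
    Accumulator<MetricsContainerStepMap> acc = jsc.sc().accumulator(
        new MetricsContainerStepMap(), "Beam.Metrics", new MetricsAccumulatorParam());
    // Workers merge their containers into acc; the driver converts the
    // final value into queryable results:
    MetricResults results = asAttemptedOnlyMetricResults(acc.value());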

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkBeamMetric.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkBeamMetric.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkBeamMetric.java
index 2d445a9..e4bd598 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkBeamMetric.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkBeamMetric.java
@@ -18,6 +18,8 @@
 
 package org.apache.beam.runners.spark.metrics;
 
+import static org.apache.beam.sdk.metrics.MetricsContainerStepMap.asAttemptedOnlyMetricResults;
+
 import com.codahale.metrics.Metric;
 import com.google.common.annotations.VisibleForTesting;
 import java.util.HashMap;
@@ -27,20 +29,23 @@ import org.apache.beam.sdk.metrics.GaugeResult;
 import org.apache.beam.sdk.metrics.MetricName;
 import org.apache.beam.sdk.metrics.MetricQueryResults;
 import org.apache.beam.sdk.metrics.MetricResult;
+import org.apache.beam.sdk.metrics.MetricResults;
+import org.apache.beam.sdk.metrics.MetricsContainerStepMap;
 import org.apache.beam.sdk.metrics.MetricsFilter;
 
 
 /**
- * An adapter between the {@link SparkMetricsContainer} and Codahale's {@link Metric} interface.
+ * An adapter between the {@link MetricsContainerStepMap} and Codahale's {@link Metric} interface.
  */
 class SparkBeamMetric implements Metric {
   private static final String ILLEGAL_CHARACTERS = "[^A-Za-z0-9\\._-]";
   private static final String ILLEGAL_CHARACTERS_AND_PERIOD = "[^A-Za-z0-9_-]";
 
-  private final SparkMetricResults metricResults = new SparkMetricResults();
-
   Map<String, ?> renderAll() {
     Map<String, Object> metrics = new HashMap<>();
+    MetricResults metricResults =
+        asAttemptedOnlyMetricResults(
+            MetricsAccumulator.getInstance().value());
     MetricQueryResults metricQueryResults =
         metricResults.queryMetrics(MetricsFilter.builder().build());
     for (MetricResult<Long> metricResult : metricQueryResults.counters()) {

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkBeamMetricSource.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkBeamMetricSource.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkBeamMetricSource.java
index 5c6fc24..03128d7 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkBeamMetricSource.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkBeamMetricSource.java
@@ -24,7 +24,7 @@ import org.apache.spark.metrics.source.Source;
 
 /**
  * A Spark {@link Source} that is tailored to expose a {@link SparkBeamMetric},
- * wrapping an underlying {@link SparkMetricsContainer} instance.
+ * wrapping an underlying {@link org.apache.beam.sdk.metrics.MetricResults} instance.
  */
 public class SparkBeamMetricSource implements Source {
   private static final String METRIC_NAME = "Metrics";

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricResults.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricResults.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricResults.java
deleted file mode 100644
index faf4c52..0000000
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricResults.java
+++ /dev/null
@@ -1,172 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.spark.metrics;
-
-import com.google.common.base.Function;
-import com.google.common.base.Predicate;
-import com.google.common.collect.FluentIterable;
-import org.apache.beam.sdk.metrics.DistributionData;
-import org.apache.beam.sdk.metrics.DistributionResult;
-import org.apache.beam.sdk.metrics.GaugeData;
-import org.apache.beam.sdk.metrics.GaugeResult;
-import org.apache.beam.sdk.metrics.MetricFiltering;
-import org.apache.beam.sdk.metrics.MetricKey;
-import org.apache.beam.sdk.metrics.MetricName;
-import org.apache.beam.sdk.metrics.MetricQueryResults;
-import org.apache.beam.sdk.metrics.MetricResult;
-import org.apache.beam.sdk.metrics.MetricResults;
-import org.apache.beam.sdk.metrics.MetricUpdates.MetricUpdate;
-import org.apache.beam.sdk.metrics.MetricsFilter;
-
-
-/**
- * Implementation of {@link MetricResults} for the Spark Runner.
- */
-public class SparkMetricResults extends MetricResults {
-
-  @Override
-  public MetricQueryResults queryMetrics(MetricsFilter filter) {
-    return new SparkMetricQueryResults(filter);
-  }
-
-  private static class SparkMetricQueryResults implements MetricQueryResults {
-    private final MetricsFilter filter;
-
-    SparkMetricQueryResults(MetricsFilter filter) {
-      this.filter = filter;
-    }
-
-    @Override
-    public Iterable<MetricResult<Long>> counters() {
-      return
-          FluentIterable
-              .from(SparkMetricsContainer.getCounters())
-              .filter(matchesFilter(filter))
-              .transform(TO_COUNTER_RESULT)
-              .toList();
-    }
-
-    @Override
-    public Iterable<MetricResult<DistributionResult>> distributions() {
-      return
-          FluentIterable
-              .from(SparkMetricsContainer.getDistributions())
-              .filter(matchesFilter(filter))
-              .transform(TO_DISTRIBUTION_RESULT)
-              .toList();
-    }
-
-    @Override
-    public Iterable<MetricResult<GaugeResult>> gauges() {
-      return
-          FluentIterable
-              .from(SparkMetricsContainer.getGauges())
-              .filter(matchesFilter(filter))
-              .transform(TO_GAUGE_RESULT)
-              .toList();
-    }
-
-    private Predicate<MetricUpdate<?>> matchesFilter(final MetricsFilter filter) {
-      return new Predicate<MetricUpdate<?>>() {
-        @Override
-        public boolean apply(MetricUpdate<?> metricResult) {
-          return MetricFiltering.matches(filter, metricResult.getKey());
-        }
-      };
-    }
-  }
-
-  private static final Function<MetricUpdate<DistributionData>, MetricResult<DistributionResult>>
-      TO_DISTRIBUTION_RESULT =
-      new Function<MetricUpdate<DistributionData>, MetricResult<DistributionResult>>() {
-        @Override
-        public MetricResult<DistributionResult> apply(MetricUpdate<DistributionData> metricResult) {
-          if (metricResult != null) {
-            MetricKey key = metricResult.getKey();
-            return new SparkMetricResult<>(key.metricName(), key.stepName(),
-                metricResult.getUpdate().extractResult());
-          } else {
-            return null;
-          }
-        }
-      };
-
-  private static final Function<MetricUpdate<Long>, MetricResult<Long>>
-      TO_COUNTER_RESULT =
-      new Function<MetricUpdate<Long>, MetricResult<Long>>() {
-        @Override
-        public MetricResult<Long> apply(MetricUpdate<Long> metricResult) {
-          if (metricResult != null) {
-            MetricKey key = metricResult.getKey();
-            return new SparkMetricResult<>(key.metricName(), key.stepName(),
-                metricResult.getUpdate());
-          } else {
-            return null;
-          }
-        }
-      };
-
-  private static final Function<MetricUpdate<GaugeData>, MetricResult<GaugeResult>>
-      TO_GAUGE_RESULT =
-      new Function<MetricUpdate<GaugeData>, MetricResult<GaugeResult>>() {
-        @Override
-        public MetricResult<GaugeResult> apply(MetricUpdate<GaugeData> metricResult) {
-          if (metricResult != null) {
-            MetricKey key = metricResult.getKey();
-            return new SparkMetricResult<>(key.metricName(), key.stepName(),
-                metricResult.getUpdate().extractResult());
-          } else {
-            return null;
-          }
-        }
-      };
-
-  private static class SparkMetricResult<T> implements MetricResult<T> {
-    private final MetricName name;
-    private final String step;
-    private final T result;
-
-    SparkMetricResult(MetricName name, String step, T result) {
-      this.name = name;
-      this.step = step;
-      this.result = result;
-    }
-
-    @Override
-    public MetricName name() {
-      return name;
-    }
-
-    @Override
-    public String step() {
-      return step;
-    }
-
-    @Override
-    public T committed() {
-      throw new UnsupportedOperationException("Spark runner does not currently support committed"
-          + " metrics results. Please use 'attempted' instead.");
-    }
-
-    @Override
-    public T attempted() {
-      return result;
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricsContainer.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricsContainer.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricsContainer.java
deleted file mode 100644
index 9e94c14..0000000
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricsContainer.java
+++ /dev/null
@@ -1,174 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.spark.metrics;
-
-import com.google.common.cache.CacheBuilder;
-import com.google.common.cache.CacheLoader;
-import com.google.common.cache.LoadingCache;
-import java.io.IOException;
-import java.io.ObjectOutputStream;
-import java.io.Serializable;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.concurrent.ExecutionException;
-import org.apache.beam.sdk.metrics.DistributionData;
-import org.apache.beam.sdk.metrics.GaugeData;
-import org.apache.beam.sdk.metrics.MetricKey;
-import org.apache.beam.sdk.metrics.MetricUpdates;
-import org.apache.beam.sdk.metrics.MetricUpdates.MetricUpdate;
-import org.apache.beam.sdk.metrics.MetricsContainer;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-
-/**
- * Spark accumulator value which holds all {@link MetricsContainer}s, aggregates and merges them.
- */
-public class SparkMetricsContainer implements Serializable {
-  private static final Logger LOG = LoggerFactory.getLogger(SparkMetricsContainer.class);
-
-  private transient volatile LoadingCache<String, MetricsContainer> metricsContainers;
-
-  private final Map<MetricKey, MetricUpdate<Long>> counters = new HashMap<>();
-  private final Map<MetricKey, MetricUpdate<DistributionData>> distributions = new HashMap<>();
-  private final Map<MetricKey, MetricUpdate<GaugeData>> gauges = new HashMap<>();
-
-  public MetricsContainer getContainer(String stepName) {
-    if (metricsContainers == null) {
-      synchronized (this) {
-        if (metricsContainers == null) {
-          initializeMetricsContainers();
-        }
-      }
-    }
-    try {
-      return metricsContainers.get(stepName);
-    } catch (ExecutionException e) {
-      LOG.error("Error while creating metrics container", e);
-      return null;
-    }
-  }
-
-  static Collection<MetricUpdate<Long>> getCounters() {
-    SparkMetricsContainer sparkMetricsContainer = getInstance();
-    sparkMetricsContainer.materialize();
-    return sparkMetricsContainer.counters.values();
-  }
-
-  static Collection<MetricUpdate<DistributionData>> getDistributions() {
-    SparkMetricsContainer sparkMetricsContainer = getInstance();
-    sparkMetricsContainer.materialize();
-    return sparkMetricsContainer.distributions.values();
-  }
-
-  static Collection<MetricUpdate<GaugeData>> getGauges() {
-    return getInstance().gauges.values();
-  }
-
-  public SparkMetricsContainer update(SparkMetricsContainer other) {
-    other.materialize();
-    this.updateCounters(other.counters.values());
-    this.updateDistributions(other.distributions.values());
-    this.updateGauges(other.gauges.values());
-    return this;
-  }
-
-  private static SparkMetricsContainer getInstance() {
-    return MetricsAccumulator.getInstance().value();
-  }
-
-  private void writeObject(ObjectOutputStream out) throws IOException {
-    // Since MetricsContainer instances are not serializable, materialize a serializable map of
-    // MetricsAggregators relating to the same metrics. This is done here, when Spark serializes
-    // the SparkMetricsContainer accumulator before sending results back to the driver at a point in
-    // time where all the metrics updates have already been made to the MetricsContainers.
-    materialize();
-    out.defaultWriteObject();
-  }
-
-  /**
-   * Materialize metrics. Must be called to enable this instance's data to be serialized correctly.
-   * This method is idempotent.
-   */
-  public void materialize() {
-    // Nullifying metricsContainers makes this method idempotent.
-    if (metricsContainers != null) {
-      for (MetricsContainer container : metricsContainers.asMap().values()) {
-        MetricUpdates cumulative = container.getCumulative();
-        this.updateCounters(cumulative.counterUpdates());
-        this.updateDistributions(cumulative.distributionUpdates());
-        this.updateGauges(cumulative.gaugeUpdates());
-      }
-      metricsContainers = null;
-    }
-  }
-
-  private void updateCounters(Iterable<MetricUpdate<Long>> updates) {
-    for (MetricUpdate<Long> update : updates) {
-      MetricKey key = update.getKey();
-      MetricUpdate<Long> current = counters.get(key);
-      counters.put(key, current != null
-          ? MetricUpdate.create(key, current.getUpdate() + update.getUpdate()) : update);
-    }
-  }
-
-  private void updateDistributions(Iterable<MetricUpdate<DistributionData>> updates) {
-    for (MetricUpdate<DistributionData> update : updates) {
-      MetricKey key = update.getKey();
-      MetricUpdate<DistributionData> current = distributions.get(key);
-      distributions.put(key, current != null
-          ? MetricUpdate.create(key, current.getUpdate().combine(update.getUpdate())) : update);
-    }
-  }
-
-  private void updateGauges(Iterable<MetricUpdate<GaugeData>> updates) {
-    for (MetricUpdate<GaugeData> update : updates) {
-      MetricKey key = update.getKey();
-      MetricUpdate<GaugeData> current = gauges.get(key);
-      gauges.put(
-          key,
-          current != null
-              ? MetricUpdate.create(key, current.getUpdate().combine(update.getUpdate()))
-              : update);
-    }
-  }
-
-  private static class MetricsContainerCacheLoader extends CacheLoader<String, MetricsContainer> {
-    @SuppressWarnings("NullableProblems")
-    @Override
-    public MetricsContainer load(String stepName) throws Exception {
-      return new MetricsContainer(stepName);
-    }
-  }
-
-  private void initializeMetricsContainers() {
-    metricsContainers = CacheBuilder.<String, SparkMetricsContainer>newBuilder()
-        .build(new MetricsContainerCacheLoader());
-  }
-
-  @Override
-  public String toString() {
-    StringBuilder sb = new StringBuilder();
-    for (Map.Entry<String, ?> metric : new SparkBeamMetric().renderAll().entrySet()) {
-      sb.append(metric.getKey()).append(": ").append(metric.getValue()).append(" ");
-    }
-    return sb.toString();
-  }
-}
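
The per-step aggregation this class did by hand now lives in the SDK's
MetricsContainerStepMap. A rough sketch of the replacement flow, under the assumption
(consistent with the rest of this patch) that the Spark accumulator now carries the
step map directly:

    MetricsContainerStepMap metricsContainers = new MetricsContainerStepMap();
    MetricsContainer container = metricsContainers.getContainer("myStep");
    // ... metric updates are reported into 'container' while the step runs ...
    MetricResults results =
        MetricsContainerStepMap.asAttemptedOnlyMetricResults(metricsContainers);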

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
index 9bc8760..37d9635 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
@@ -321,12 +321,12 @@ public class SparkGroupAlsoByWindowViaWindowSet {
         long lateDropped = droppedDueToLateness.getCumulative();
         if (lateDropped > 0) {
           LOG.info(String.format("Dropped %d elements due to lateness.", lateDropped));
-          droppedDueToLateness.inc(-droppedDueToLateness.getCumulative());
+          droppedDueToLateness.update(-droppedDueToLateness.getCumulative());
         }
         long closedWindowDropped = droppedDueToClosedWindow.getCumulative();
         if (closedWindowDropped > 0) {
           LOG.info(String.format("Dropped %d elements due to closed window.", closedWindowDropped));
-          droppedDueToClosedWindow.inc(-droppedDueToClosedWindow.getCumulative());
+          droppedDueToClosedWindow.update(-droppedDueToClosedWindow.getCumulative());
         }
 
         return scala.collection.JavaConversions.asScalaIterator(outIter);
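
The switch from inc(-x) to update(-x) follows the CounterCell API change later in this
patch, which removes inc()/inc(long) in favor of update(long). The idiom itself is
unchanged: with no explicit reset method (as far as this patch shows), applying the
negative of the cumulative value zeroes the counter before the next micro-batch:

    long dropped = droppedDueToLateness.getCumulative();
    if (dropped > 0) {
      droppedDueToLateness.update(-dropped);  // zero the counter for the next batch
    }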

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java
index d8d52c4..17a3c73 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java
@@ -31,12 +31,12 @@ import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.runners.spark.io.EmptyCheckpointMark;
 import org.apache.beam.runners.spark.io.MicrobatchSource;
 import org.apache.beam.runners.spark.io.SparkUnboundedSource.Metadata;
-import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.io.Source;
 import org.apache.beam.sdk.io.UnboundedSource;
 import org.apache.beam.sdk.metrics.MetricsContainer;
+import org.apache.beam.sdk.metrics.MetricsContainerStepMap;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
@@ -110,8 +110,8 @@ public class StateSpecFunctions {
           scala.Option<CheckpointMarkT> startCheckpointMark,
           State<Tuple2<byte[], Instant>> state) {
 
-        SparkMetricsContainer sparkMetricsContainer = new SparkMetricsContainer();
-        MetricsContainer metricsContainer = sparkMetricsContainer.getContainer(stepName);
+        MetricsContainerStepMap metricsContainers = new MetricsContainerStepMap();
+        MetricsContainer metricsContainer = metricsContainers.getContainer(stepName);
 
         // Add metrics container to the scope of org.apache.beam.sdk.io.Source.Reader methods
         // since they may report metrics.
@@ -214,7 +214,7 @@ public class StateSpecFunctions {
                 lowWatermark,
                 highWatermark,
                 readDurationMillis,
-                sparkMetricsContainer));
+                metricsContainers));
 
         } catch (IOException e) {
           throw new RuntimeException(e);
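
A condensed sketch of the scoping pattern used above, assuming
MetricsEnvironment.scopedMetricsContainer returns a Closeable that restores the
previous container on close (stepName comes from the enclosing scope):

    MetricsContainerStepMap metricsContainers = new MetricsContainerStepMap();
    MetricsContainer metricsContainer = metricsContainers.getContainer(stepName);
    try (Closeable ignored =
             MetricsEnvironment.scopedMetricsContainer(metricsContainer)) {
      // Source.Reader calls made here report metrics into metricsContainer.
    } catch (IOException e) {
      throw new RuntimeException(e);
    }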

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java
index d74b253..8349b09 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java
@@ -21,8 +21,8 @@ package org.apache.beam.runners.spark.translation;
 import java.io.Closeable;
 import java.io.IOException;
 import org.apache.beam.runners.core.DoFnRunner;
-import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.sdk.metrics.MetricsContainer;
+import org.apache.beam.sdk.metrics.MetricsContainerStepMap;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.state.TimeDomain;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -37,12 +37,12 @@ import org.joda.time.Instant;
 class DoFnRunnerWithMetrics<InputT, OutputT> implements DoFnRunner<InputT, OutputT> {
   private final DoFnRunner<InputT, OutputT> delegate;
   private final String stepName;
-  private final Accumulator<SparkMetricsContainer> metricsAccum;
+  private final Accumulator<MetricsContainerStepMap> metricsAccum;
 
   DoFnRunnerWithMetrics(
       String stepName,
       DoFnRunner<InputT, OutputT> delegate,
-      Accumulator<SparkMetricsContainer> metricsAccum) {
+      Accumulator<MetricsContainerStepMap> metricsAccum) {
     this.delegate = delegate;
     this.stepName = stepName;
     this.metricsAccum = metricsAccum;

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
index 9bfd2fa..ecf96b6 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
@@ -28,9 +28,9 @@ import java.util.Map;
 import org.apache.beam.runners.core.DoFnRunner;
 import org.apache.beam.runners.core.DoFnRunners;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
-import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.runners.spark.util.SideInputBroadcast;
 import org.apache.beam.runners.spark.util.SparkSideInputReader;
+import org.apache.beam.sdk.metrics.MetricsContainerStepMap;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
@@ -52,7 +52,7 @@ public class MultiDoFnFunction<InputT, OutputT>
     implements PairFlatMapFunction<Iterator<WindowedValue<InputT>>, TupleTag<?>, WindowedValue<?>> {
 
   private final Accumulator<NamedAggregators> aggAccum;
-  private final Accumulator<SparkMetricsContainer> metricsAccum;
+  private final Accumulator<MetricsContainerStepMap> metricsAccum;
   private final String stepName;
   private final DoFn<InputT, OutputT> doFn;
   private final SparkRuntimeContext runtimeContext;
@@ -71,7 +71,7 @@ public class MultiDoFnFunction<InputT, OutputT>
    */
   public MultiDoFnFunction(
       Accumulator<NamedAggregators> aggAccum,
-      Accumulator<SparkMetricsContainer> metricsAccum,
+      Accumulator<MetricsContainerStepMap> metricsAccum,
       String stepName,
       DoFn<InputT, OutputT> doFn,
       SparkRuntimeContext runtimeContext,

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 8a8e246..acbac32 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -35,13 +35,13 @@ import org.apache.beam.runners.spark.aggregators.NamedAggregators;
 import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.runners.spark.io.SourceRDD;
 import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
-import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.runners.spark.util.SideInputBroadcast;
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.io.Read;
+import org.apache.beam.sdk.metrics.MetricsContainerStepMap;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.CombineWithContext;
 import org.apache.beam.sdk.transforms.Create;
@@ -359,7 +359,7 @@ public final class TransformTranslator {
         WindowingStrategy<?, ?> windowingStrategy =
             context.getInput(transform).getWindowingStrategy();
         Accumulator<NamedAggregators> aggAccum = AggregatorsAccumulator.getInstance();
-        Accumulator<SparkMetricsContainer> metricsAccum = MetricsAccumulator.getInstance();
+        Accumulator<MetricsContainerStepMap> metricsAccum = MetricsAccumulator.getInstance();
         JavaPairRDD<TupleTag<?>, WindowedValue<?>> all =
             inRDD.mapPartitionsToPair(
                 new MultiDoFnFunction<>(

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 2c4a747..f736e53 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -39,7 +39,6 @@ import org.apache.beam.runners.spark.io.ConsoleIO;
 import org.apache.beam.runners.spark.io.CreateStream;
 import org.apache.beam.runners.spark.io.SparkUnboundedSource;
 import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
-import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.runners.spark.stateful.SparkGroupAlsoByWindowViaWindowSet;
 import org.apache.beam.runners.spark.translation.BoundedDataset;
 import org.apache.beam.runners.spark.translation.Dataset;
@@ -59,6 +58,7 @@ import org.apache.beam.runners.spark.util.SideInputBroadcast;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.io.Read;
+import org.apache.beam.sdk.metrics.MetricsContainerStepMap;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.CombineWithContext;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -395,7 +395,7 @@ public final class StreamingTransformTranslator {
                       JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
                     final Accumulator<NamedAggregators> aggAccum =
                         AggregatorsAccumulator.getInstance();
-                    final Accumulator<SparkMetricsContainer> metricsAccum =
+                    final Accumulator<MetricsContainerStepMap> metricsAccum =
                         MetricsAccumulator.getInstance();
                     final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
                         sideInputs =

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/CounterCell.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/CounterCell.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/CounterCell.java
index 7ab5ebc..4b8548f 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/CounterCell.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/CounterCell.java
@@ -30,7 +30,7 @@ import org.apache.beam.sdk.annotations.Experimental.Kind;
  * indirection.
  */
 @Experimental(Kind.METRICS)
-public class CounterCell implements MetricCell<Long> {
+public class CounterCell implements MetricCell<Counter, Long> {
 
   private final DirtyState dirty = new DirtyState();
   private final AtomicLong value = new AtomicLong();
@@ -41,13 +41,26 @@ public class CounterCell implements MetricCell<Long> {
    */
   CounterCell() {}
 
-  /** Increment the counter by the given amount. */
-  private void add(long n) {
+  /**
+   * Increment the counter by the given amount.
+   *
+   * @param n value to increment by; can be negative to decrement.
+   */
+  public void update(long n) {
     value.addAndGet(n);
     dirty.afterModification();
   }
 
   @Override
+  public void update(Long n) {
+    throw new UnsupportedOperationException("CounterCell.update(Long n) should not be used"
+    + " as it performs unnecessary boxing/unboxing. Use CounterCell.update(long n) instead.");
+  }
+
+  @Override
+  public void update(MetricCell<Counter, Long> other) {
+    update((long) other.getCumulative());
+  }
+
+  @Override
   public DirtyState getDirty() {
     return dirty;
   }
@@ -56,12 +69,4 @@ public class CounterCell implements MetricCell<Long> {
   public Long getCumulative() {
     return value.get();
   }
-
-  public void inc() {
-    add(1);
-  }
-
-  public void inc(long n) {
-    add(n);
-  }
 }
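
With inc() and inc(long) gone, callers move to the single update(long) entry point. A
tiny sketch of the new semantics (the constructor is package-private, so picture this
inside the same package, e.g. a test):

    CounterCell cell = new CounterCell();
    cell.update(5L);   // was cell.inc(5)
    cell.update(-2L);  // a negative delta decrements
    assert cell.getCumulative() == 3L;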

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DirtyState.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DirtyState.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DirtyState.java
index 6706be8..4e0c15c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DirtyState.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DirtyState.java
@@ -18,6 +18,7 @@
 
 package org.apache.beam.sdk.metrics;
 
+import java.io.Serializable;
 import java.util.concurrent.atomic.AtomicReference;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.annotations.Experimental.Kind;
@@ -41,7 +42,7 @@ import org.apache.beam.sdk.annotations.Experimental.Kind;
  * completed.
  */
 @Experimental(Kind.METRICS)
-class DirtyState {
+class DirtyState implements Serializable {
   private enum State {
     /** Indicates that there have been changes to the MetricCell since last commit. */
     DIRTY,

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DistributionCell.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DistributionCell.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DistributionCell.java
index 0f3f6a4..93a3649 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DistributionCell.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DistributionCell.java
@@ -30,7 +30,7 @@ import org.apache.beam.sdk.annotations.Experimental.Kind;
  * of indirection.
  */
 @Experimental(Kind.METRICS)
-public class DistributionCell implements MetricCell<DistributionData> {
+public class DistributionCell implements MetricCell<Distribution, DistributionData> {
 
   private final DirtyState dirty = new DirtyState();
   private final AtomicReference<DistributionData> value =
@@ -42,16 +42,26 @@ public class DistributionCell implements MetricCell<DistributionData> {
    */
   DistributionCell() {}
 
-  /** Increment the counter by the given amount. */
+  /** Record the given value in the distribution. */
   public void update(long n) {
+    update(DistributionData.singleton(n));
+  }
+
+  @Override
+  public void update(DistributionData data) {
     DistributionData original;
     do {
       original = value.get();
-    } while (!value.compareAndSet(original, original.combine(DistributionData.singleton(n))));
+    } while (!value.compareAndSet(original, original.combine(data)));
     dirty.afterModification();
   }
 
   @Override
+  public void update(MetricCell<Distribution, DistributionData> other) {
+    update(other.getCumulative());
+  }
+
+  @Override
   public DirtyState getDirty() {
     return dirty;
   }
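
The long overload now funnels through the new update(DistributionData) path. A small
sketch of the resulting behavior (again assuming same-package access to the
constructor):

    DistributionCell cell = new DistributionCell();
    cell.update(10L);  // recorded as DistributionData.singleton(10)
    cell.update(20L);
    DistributionData data = cell.getCumulative();
    // data.extractResult() yields the DistributionResult (count=2, sum=30, min=10, max=20)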

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/GaugeCell.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/GaugeCell.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/GaugeCell.java
index 6f8e880..0cdd568 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/GaugeCell.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/GaugeCell.java
@@ -29,17 +29,33 @@ import org.apache.beam.sdk.annotations.Experimental;
  * of indirection.
  */
 @Experimental(Experimental.Kind.METRICS)
-public class GaugeCell implements MetricCell<GaugeData> {
+public class GaugeCell implements MetricCell<Gauge, GaugeData> {
 
   private final DirtyState dirty = new DirtyState();
   private final AtomicReference<GaugeData> gaugeValue = new AtomicReference<>(GaugeData.empty());
 
+  /** Set the gauge to the given value. */
   public void set(long value) {
+    update(GaugeData.create(value));
+  }
+
+  @Override
+  public void update(GaugeData data) {
+    GaugeData original;
+    do {
+      original = gaugeValue.get();
+    } while (!gaugeValue.compareAndSet(original, original.combine(data)));
+    dirty.afterModification();
+  }
+
+  @Override
+  public void update(MetricCell<Gauge, GaugeData> other) {
-    GaugeData original;
-    do {
-      original = gaugeValue.get();
-    } while (!gaugeValue.compareAndSet(original, original.combine(GaugeData.create(value))));
-    dirty.afterModification();
+    update(other.getCumulative());
   }
 
   @Override
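
set(long) is now sugar over the same compare-and-set path as update(GaugeData), and
merging another cell delegates to it, mirroring DistributionCell. A brief sketch,
assuming GaugeData.combine keeps the most recently written value:

    GaugeCell cell = new GaugeCell();
    cell.set(42L);                     // same as update(GaugeData.create(42))
    cell.update(GaugeData.create(7L)); // combine keeps the newer value (assumption)
    GaugeData latest = cell.getCumulative();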

http://git-wip-us.apache.org/repos/asf/beam/blob/3a4ffd2c/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricCell.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricCell.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricCell.java
index 82e30cb..403cac2 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricCell.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricCell.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.metrics;
 
+import java.io.Serializable;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.annotations.Experimental.Kind;
 
@@ -24,10 +25,21 @@ import org.apache.beam.sdk.annotations.Experimental.Kind;
  * A {@link MetricCell} is used for accumulating in-memory changes to a metric. It represents a
  * specific metric name in a single context.
  *
+ * @param <UserT> The type of the user interface for reporting changes to this cell.
  * @param <DataT> The type of metric data stored (and extracted) from this cell.
  */
 @Experimental(Kind.METRICS)
-public interface MetricCell<DataT> {
+public interface MetricCell<UserT extends Metric, DataT> extends Serializable {
+
+  /**
+   * Update the value of this cell with the given data.
+   */
+  void update(DataT data);
+
+  /**
+   * Update the value of this cell by merging in the value of another cell.
+   */
+  void update(MetricCell<UserT, DataT> other);
 
   /**
    * Return the {@link DirtyState} tracking whether this metric cell contains uncommitted changes.