You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by tg...@apache.org on 2017/11/18 01:23:50 UTC

[3/4] beam git commit: Add A TransformExecutorFactory

Add A TransformExecutorFactory

This creates executors given an input bundle, a transform, a
completion callback, and a service on which to execute the transform.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/8101103b
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/8101103b
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/8101103b

Branch: refs/heads/master
Commit: 8101103beb4de82acb1d4785097281c071f124d7
Parents: c8d45d4
Author: Thomas Groh <tg...@google.com>
Authored: Mon Nov 13 18:08:17 2017 -0800
Committer: Thomas Groh <tg...@google.com>
Committed: Fri Nov 17 17:23:41 2017 -0800

----------------------------------------------------------------------
 .../runners/direct/DirectTransformExecutor.java | 196 +++++++
 .../direct/ExecutorServiceParallelExecutor.java |  56 +-
 .../beam/runners/direct/TransformExecutor.java  | 165 +-----
 .../direct/TransformExecutorFactory.java        |  32 ++
 .../direct/TransformExecutorService.java        |   6 +-
 .../direct/TransformExecutorServices.java       |  14 +-
 .../direct/DirectRunnerApiSurfaceTest.java      |   5 +-
 .../direct/DirectTransformExecutorTest.java     | 537 +++++++++++++++++++
 .../direct/TransformExecutorServicesTest.java   |  40 +-
 .../runners/direct/TransformExecutorTest.java   | 537 -------------------
 10 files changed, 823 insertions(+), 765 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/8101103b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectTransformExecutor.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectTransformExecutor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectTransformExecutor.java
new file mode 100644
index 0000000..a5aa809
--- /dev/null
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectTransformExecutor.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.direct;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.MoreObjects;
+import java.io.Closeable;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Map;
+import java.util.concurrent.Callable;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.metrics.MetricUpdates;
+import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
+import org.apache.beam.sdk.metrics.MetricsEnvironment;
+import org.apache.beam.sdk.runners.AppliedPTransform;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A {@link TransformExecutor} (a {@link Runnable}, not a {@link Callable}) that builds a
+ * {@link TransformEvaluator} from a {@link TransformEvaluatorFactory}, evaluates it on one
+ * input bundle, and reports the outcome through a {@link CompletionCallback}.
+ */
+class DirectTransformExecutor<T> implements TransformExecutor {
+  private static final Logger LOG = LoggerFactory.getLogger(DirectTransformExecutor.class);
+  /** Creates executors, looking up model enforcements by transform URN (none if unmapped). */
+  static class Factory implements TransformExecutorFactory {
+    private final EvaluationContext context;
+    private final TransformEvaluatorRegistry registry;
+    private final Map<String, Collection<ModelEnforcementFactory>> transformEnforcements;
+
+    Factory(
+        EvaluationContext context,
+        TransformEvaluatorRegistry registry,
+        Map<String, Collection<ModelEnforcementFactory>> transformEnforcements) {
+      this.context = context;
+      this.registry = registry;
+      this.transformEnforcements = transformEnforcements;
+    }
+
+    @Override
+    public TransformExecutor create(
+        CommittedBundle<?> bundle,
+        AppliedPTransform<?, ?, ?> transform,
+        CompletionCallback onComplete,
+        TransformExecutorService executorService) {
+      Collection<ModelEnforcementFactory> enforcements =
+          MoreObjects.firstNonNull(
+              transformEnforcements.get(
+                  PTransformTranslation.urnForTransform(transform.getTransform())),
+              Collections.<ModelEnforcementFactory>emptyList());
+      return new DirectTransformExecutor<>(
+          context, registry, enforcements, bundle, transform, onComplete, executorService);
+    }
+  }
+
+  private final TransformEvaluatorFactory evaluatorFactory;
+  private final Iterable<? extends ModelEnforcementFactory> modelEnforcements;
+
+  /** The transform that will be evaluated. */
+  private final AppliedPTransform<?, ?, ?> transform;
+  /** The inputs this {@link DirectTransformExecutor} will deliver to the transform. */
+  private final CommittedBundle<T> inputBundle;
+  /** Receives the evaluation outcome: empty, result, exception, or error. */
+  private final CompletionCallback onComplete;
+  private final TransformExecutorService transformEvaluationState;
+  private final EvaluationContext context;
+
+  @VisibleForTesting
+  DirectTransformExecutor(
+      EvaluationContext context,
+      TransformEvaluatorFactory factory,
+      Iterable<? extends ModelEnforcementFactory> modelEnforcements,
+      CommittedBundle<T> inputBundle,
+      AppliedPTransform<?, ?, ?> transform,
+      CompletionCallback completionCallback,
+      TransformExecutorService transformEvaluationState) {
+    this.evaluatorFactory = factory;
+    this.modelEnforcements = modelEnforcements;
+
+    this.inputBundle = inputBundle;
+    this.transform = transform;
+
+    this.onComplete = completionCallback;
+
+    this.transformEvaluationState = transformEvaluationState;
+    this.context = context;
+  }
+
+  @Override
+  public void run() {
+    MetricsContainerImpl metricsContainer = new MetricsContainerImpl(transform.getFullName());
+    try (Closeable metricsScope = MetricsEnvironment.scopedMetricsContainer(metricsContainer)) {
+      Collection<ModelEnforcement<T>> enforcements = new ArrayList<>();
+      for (ModelEnforcementFactory enforcementFactory : modelEnforcements) {
+        ModelEnforcement<T> enforcement = enforcementFactory.forBundle(inputBundle, transform);
+        enforcements.add(enforcement);
+      }
+      TransformEvaluator<T> evaluator =
+          evaluatorFactory.forApplication(transform, inputBundle);
+      if (evaluator == null) {
+        onComplete.handleEmpty(transform);
+        // No evaluator was produced; handleEmpty above already notified the callback.
+        return;
+      }
+
+      processElements(evaluator, metricsContainer, enforcements);
+
+      finishBundle(evaluator, metricsContainer, enforcements);
+    } catch (Exception e) {
+      onComplete.handleException(inputBundle, e);
+      if (e instanceof RuntimeException) {
+        throw (RuntimeException) e;
+      }
+      throw new RuntimeException(e);
+    } catch (Error err) {
+      LOG.error("Error occurred within {}", this, err);
+      onComplete.handleError(err);
+      throw err;
+    } finally {
+      // On every exit path, commit the cumulative physical metrics, then mark this work complete.
+      context.getMetrics().commitPhysical(inputBundle, metricsContainer.getCumulative());
+      // NOTE(review): complete(this) is skipped if commitPhysical throws; confirm that's intended.
+      transformEvaluationState.complete(this);
+    }
+  }
+
+  /**
+   * Feeds every element of the input bundle (skipped if null) to the evaluator, wrapped in each
+   * enforcement's before/after-element hooks, and publishes physical metric deltas per element.
+   */
+  private void processElements(
+      TransformEvaluator<T> evaluator,
+      MetricsContainerImpl metricsContainer,
+      Collection<ModelEnforcement<T>> enforcements)
+      throws Exception {
+    if (inputBundle != null) {
+      for (WindowedValue<T> value : inputBundle.getElements()) {
+        for (ModelEnforcement<T> enforcement : enforcements) {
+          enforcement.beforeElement(value);
+        }
+
+        evaluator.processElement(value);
+
+        // Publish this element's metric deltas as physical updates, then mark them committed.
+        MetricUpdates deltas = metricsContainer.getUpdates();
+        if (deltas != null) {
+          context.getMetrics().updatePhysical(inputBundle, deltas);
+          metricsContainer.commitUpdates();
+        }
+
+        for (ModelEnforcement<T> enforcement : enforcements) {
+          enforcement.afterElement(value);
+        }
+      }
+    }
+  }
+
+  /**
+   * Finishes the bundle, commits the result via {@link CompletionCallback#handleResult}, and then
+   * runs each {@link ModelEnforcement#afterFinish} hook against the committed outputs.
+   *
+   * @return the {@link TransformResult} from {@link TransformEvaluator#finishBundle()}, carrying
+   *     the cumulative metrics as logical metric updates
+   */
+  private TransformResult<T> finishBundle(
+      TransformEvaluator<T> evaluator, MetricsContainerImpl metricsContainer,
+      Collection<ModelEnforcement<T>> enforcements)
+      throws Exception {
+    TransformResult<T> result =
+        evaluator.finishBundle().withLogicalMetricUpdates(metricsContainer.getCumulative());
+    CommittedResult outputs = onComplete.handleResult(inputBundle, result);
+    for (ModelEnforcement<T> enforcement : enforcements) {
+      enforcement.afterFinish(inputBundle, result, outputs.getOutputs());
+    }
+    return result;
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/8101103b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java
index 75e2562..fe3765b 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java
@@ -20,7 +20,6 @@ package org.apache.beam.runners.direct;
 import static com.google.common.base.Preconditions.checkState;
 
 import com.google.auto.value.AutoValue;
-import com.google.common.base.MoreObjects;
 import com.google.common.base.Optional;
 import com.google.common.cache.CacheBuilder;
 import com.google.common.cache.CacheLoader;
@@ -49,7 +48,6 @@ import javax.annotation.Nullable;
 import org.apache.beam.runners.core.KeyedWorkItem;
 import org.apache.beam.runners.core.KeyedWorkItems;
 import org.apache.beam.runners.core.TimerInternals.TimerData;
-import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.direct.WatermarkManager.FiredTimers;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.PipelineResult.State;
@@ -77,32 +75,33 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor {
   private final DirectGraph graph;
   private final RootProviderRegistry rootProviderRegistry;
   private final TransformEvaluatorRegistry registry;
-  private final Map<String, Collection<ModelEnforcementFactory>> transformEnforcements;
 
   private final EvaluationContext evaluationContext;
 
-  private final LoadingCache<StepAndKey, TransformExecutorService> executorServices;
+  private final TransformExecutorFactory executorFactory;
+  private final TransformExecutorService parallelExecutorService;
+  private final LoadingCache<StepAndKey, TransformExecutorService> serialExecutorServices;
 
   private final Queue<ExecutorUpdate> allUpdates;
   private final BlockingQueue<VisibleExecutorUpdate> visibleUpdates;
 
-  private final TransformExecutorService parallelExecutorService;
   private final CompletionCallback defaultCompletionCallback;
 
   private final ConcurrentMap<AppliedPTransform<?, ?, ?>, ConcurrentLinkedQueue<CommittedBundle<?>>>
       pendingRootBundles;
 
- private final AtomicReference<ExecutorState> state =
+  private final AtomicReference<ExecutorState> state =
       new AtomicReference<>(ExecutorState.QUIESCENT);
 
   /**
-   * Measures the number of {@link TransformExecutor TransformExecutors} that have been scheduled
-   * but not yet completed.
+   * Measures the number of {@link TransformExecutor TransformExecutors} that have been
+   * scheduled but not yet completed.
    *
-   * <p>Before a {@link TransformExecutor} is scheduled, this value is incremented. All methods in
-   * {@link CompletionCallback} decrement this value.
+   * <p>Before a {@link TransformExecutor} is scheduled, this value is incremented. All
+   * methods in {@link CompletionCallback} decrement this value.
    */
   private final AtomicLong outstandingWork = new AtomicLong();
+
   private AtomicReference<State> pipelineState = new AtomicReference<>(State.RUNNING);
 
   public static ExecutorServiceParallelExecutor create(
@@ -141,13 +140,12 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor {
     this.graph = graph;
     this.rootProviderRegistry = rootProviderRegistry;
     this.registry = registry;
-    this.transformEnforcements = transformEnforcements;
     this.evaluationContext = context;
 
     // Weak Values allows TransformExecutorServices that are no longer in use to be reclaimed.
     // Executing TransformExecutorServices have a strong reference to their TransformExecutorService
     // which stops the TransformExecutorServices from being prematurely garbage collected
-    executorServices =
+    serialExecutorServices =
         CacheBuilder.newBuilder()
             .weakValues()
             .removalListener(shutdownExecutorServiceListener())
@@ -160,6 +158,7 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor {
     defaultCompletionCallback =
         new TimerIterableCompletionCallback(Collections.<TimerData>emptyList());
     this.pendingRootBundles = new ConcurrentHashMap<>();
+    executorFactory = new DirectTransformExecutor.Factory(context, registry, transformEnforcements);
   }
 
   private CacheLoader<StepAndKey, TransformExecutorService>
@@ -222,29 +221,16 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor {
       final StepAndKey stepAndKey = StepAndKey.of(transform, bundle.getKey());
       // This executor will remain reachable until it has executed all scheduled transforms.
       // The TransformExecutors keep a strong reference to the Executor, the ExecutorService keeps
-      // a reference to the scheduled TransformExecutor callable. Follow-up TransformExecutors
-      // (scheduled due to the completion of another TransformExecutor) are provided to the
-      // ExecutorService before the Earlier TransformExecutor callable completes.
-      transformExecutor = executorServices.getUnchecked(stepAndKey);
+      // a reference to the scheduled TransformExecutor. Follow-up TransformExecutors
+      // (scheduled due to the completion of another TransformExecutor) are provided to the
+      // ExecutorService before the earlier TransformExecutor completes.
+      transformExecutor = serialExecutorServices.getUnchecked(stepAndKey);
     } else {
       transformExecutor = parallelExecutorService;
     }
 
-    Collection<ModelEnforcementFactory> enforcements =
-        MoreObjects.firstNonNull(
-            transformEnforcements.get(
-                PTransformTranslation.urnForTransform(transform.getTransform())),
-            Collections.<ModelEnforcementFactory>emptyList());
-
-    TransformExecutor<T> callable =
-        TransformExecutor.create(
-            evaluationContext,
-            registry,
-            enforcements,
-            bundle,
-            transform,
-            onComplete,
-            transformExecutor);
+    TransformExecutor callable =
+        executorFactory.create(bundle, transform, onComplete, transformExecutor);
     outstandingWork.incrementAndGet();
     if (!pipelineState.get().isTerminal()) {
       transformExecutor.schedule(callable);
@@ -321,8 +307,8 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor {
     pipelineState.compareAndSet(State.RUNNING, newState);
     // Stop accepting new work before shutting down the executor. This ensures that thread don't try
     // to add work to the shutdown executor.
-    executorServices.invalidateAll();
-    executorServices.cleanUp();
+    serialExecutorServices.invalidateAll();
+    serialExecutorServices.cleanUp();
     parallelExecutorService.shutdown();
     executorService.shutdown();
     try {
@@ -584,7 +570,7 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor {
     }
 
     /**
-     * If all active {@link TransformExecutor TransformExecutors} are in a blocked state,
+     * If all active {@link TransformExecutor TransformExecutors} are in a blocked state,
      * add more work from root nodes that may have additional work. This ensures that if a pipeline
      * has elements available from the root nodes it will add those elements when necessary.
      */
@@ -621,7 +607,7 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor {
      * been evaluated, and all pending, including potentially blocked work, should be evaluated.
      *
      * <p>The executor becomes active whenever a timer fires, a {@link PCollectionView} is updated,
-     * or output is produced by the evaluation of a {@link TransformExecutor}.
+     * or output is produced by the evaluation of a {@link TransformExecutor}.
      */
     ACTIVE,
     /**

http://git-wip-us.apache.org/repos/asf/beam/blob/8101103b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutor.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutor.java
index 76d817b..1e269de 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutor.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutor.java
@@ -15,167 +15,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.beam.runners.direct;
-
-import java.io.Closeable;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.concurrent.Callable;
-import org.apache.beam.runners.core.metrics.MetricUpdates;
-import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
-import org.apache.beam.sdk.metrics.MetricsEnvironment;
-import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * A {@link Callable} responsible for constructing a {@link TransformEvaluator} from a
- * {@link TransformEvaluatorFactory} and evaluating it on some bundle of input, and registering
- * the result using a registered {@link CompletionCallback}.
- *
- * <p>A {@link TransformExecutor} that is currently executing also provides access to the thread
- * that it is being executed on.
- */
-class TransformExecutor<T> implements Runnable {
-  private static final Logger LOG = LoggerFactory.getLogger(TransformExecutor.class);
-
-  public static <T> TransformExecutor<T> create(
-      EvaluationContext context,
-      TransformEvaluatorFactory factory,
-      Iterable<? extends ModelEnforcementFactory> modelEnforcements,
-      CommittedBundle<T> inputBundle,
-      AppliedPTransform<?, ?, ?> transform,
-      CompletionCallback completionCallback,
-      TransformExecutorService transformEvaluationState) {
-    return new TransformExecutor<>(
-        context,
-        factory,
-        modelEnforcements,
-        inputBundle,
-        transform,
-        completionCallback,
-        transformEvaluationState);
-  }
-
-  private final TransformEvaluatorFactory evaluatorFactory;
-  private final Iterable<? extends ModelEnforcementFactory> modelEnforcements;
-
-  /** The transform that will be evaluated. */
-  private final AppliedPTransform<?, ?, ?> transform;
-  /** The inputs this {@link TransformExecutor} will deliver to the transform. */
-  private final CommittedBundle<T> inputBundle;
-
-  private final CompletionCallback onComplete;
-  private final TransformExecutorService transformEvaluationState;
-  private final EvaluationContext context;
-
-  private TransformExecutor(
-      EvaluationContext context,
-      TransformEvaluatorFactory factory,
-      Iterable<? extends ModelEnforcementFactory> modelEnforcements,
-      CommittedBundle<T> inputBundle,
-      AppliedPTransform<?, ?, ?> transform,
-      CompletionCallback completionCallback,
-      TransformExecutorService transformEvaluationState) {
-    this.evaluatorFactory = factory;
-    this.modelEnforcements = modelEnforcements;
-
-    this.inputBundle = inputBundle;
-    this.transform = transform;
 
-    this.onComplete = completionCallback;
-
-    this.transformEvaluationState = transformEvaluationState;
-    this.context = context;
-  }
-
-  @Override
-  public void run() {
-    MetricsContainerImpl metricsContainer = new MetricsContainerImpl(transform.getFullName());
-    try (Closeable metricsScope = MetricsEnvironment.scopedMetricsContainer(metricsContainer)) {
-      Collection<ModelEnforcement<T>> enforcements = new ArrayList<>();
-      for (ModelEnforcementFactory enforcementFactory : modelEnforcements) {
-        ModelEnforcement<T> enforcement = enforcementFactory.forBundle(inputBundle, transform);
-        enforcements.add(enforcement);
-      }
-      TransformEvaluator<T> evaluator =
-          evaluatorFactory.forApplication(transform, inputBundle);
-      if (evaluator == null) {
-        onComplete.handleEmpty(transform);
-        // Nothing to do
-        return;
-      }
-
-      processElements(evaluator, metricsContainer, enforcements);
-
-      finishBundle(evaluator, metricsContainer, enforcements);
-    } catch (Exception e) {
-      onComplete.handleException(inputBundle, e);
-      if (e instanceof RuntimeException) {
-        throw (RuntimeException) e;
-      }
-      throw new RuntimeException(e);
-    } catch (Error err) {
-      LOG.error("Error occurred within {}", this, err);
-      onComplete.handleError(err);
-      throw err;
-    } finally {
-      // Report the physical metrics from the end of this step.
-      context.getMetrics().commitPhysical(inputBundle, metricsContainer.getCumulative());
-
-      transformEvaluationState.complete(this);
-    }
-  }
-
-  /**
-   * Processes all the elements in the input bundle using the transform evaluator, applying any
-   * necessary {@link ModelEnforcement ModelEnforcements}.
-   */
-  private void processElements(
-      TransformEvaluator<T> evaluator,
-      MetricsContainerImpl metricsContainer,
-      Collection<ModelEnforcement<T>> enforcements)
-      throws Exception {
-    if (inputBundle != null) {
-      for (WindowedValue<T> value : inputBundle.getElements()) {
-        for (ModelEnforcement<T> enforcement : enforcements) {
-          enforcement.beforeElement(value);
-        }
-
-        evaluator.processElement(value);
-
-        // Report the physical metrics after each element
-        MetricUpdates deltas = metricsContainer.getUpdates();
-        if (deltas != null) {
-          context.getMetrics().updatePhysical(inputBundle, deltas);
-          metricsContainer.commitUpdates();
-        }
-
-        for (ModelEnforcement<T> enforcement : enforcements) {
-          enforcement.afterElement(value);
-        }
-      }
-    }
-  }
+package org.apache.beam.runners.direct;
 
-  /**
-   * Finishes processing the input bundle and commit the result using the
-   * {@link CompletionCallback}, applying any {@link ModelEnforcement} if necessary.
-   *
-   * @return the {@link TransformResult} produced by
-   *         {@link TransformEvaluator#finishBundle()}
-   */
-  private TransformResult<T> finishBundle(
-      TransformEvaluator<T> evaluator, MetricsContainerImpl metricsContainer,
-      Collection<ModelEnforcement<T>> enforcements)
-      throws Exception {
-    TransformResult<T> result =
-        evaluator.finishBundle().withLogicalMetricUpdates(metricsContainer.getCumulative());
-    CommittedResult outputs = onComplete.handleResult(inputBundle, result);
-    for (ModelEnforcement<T> enforcement : enforcements) {
-      enforcement.afterFinish(inputBundle, result, outputs.getOutputs());
-    }
-    return result;
-  }
-}
+/** A {@link Runnable} that executes a {@code PTransform} on one bundle of input. */
+public interface TransformExecutor extends Runnable {}

http://git-wip-us.apache.org/repos/asf/beam/blob/8101103b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutorFactory.java
new file mode 100644
index 0000000..d78c265
--- /dev/null
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutorFactory.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import org.apache.beam.sdk.runners.AppliedPTransform;
+
+/**
+ * Creates {@link TransformExecutor TransformExecutors} that run a transform over one input bundle.
+ */
+interface TransformExecutorFactory {
+  TransformExecutor create(
+      CommittedBundle<?> bundle,
+      AppliedPTransform<?, ?, ?> transform,
+      CompletionCallback onComplete,
+      TransformExecutorService executorService);
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/8101103b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutorService.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutorService.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutorService.java
index c6f770f..90b3960 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutorService.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutorService.java
@@ -25,13 +25,13 @@ interface TransformExecutorService {
   /**
    * Schedule the provided work to be eventually executed.
    */
-  void schedule(TransformExecutor<?> work);
+  void schedule(TransformExecutor work);
 
   /**
    * Finish executing the provided work. This may cause additional
-   * {@link TransformExecutor TransformExecutors} to be evaluated.
+   * {@link TransformExecutor TransformExecutors} to be evaluated.
    */
-  void complete(TransformExecutor<?> completed);
+  void complete(TransformExecutor completed);
 
   /**
    * Cancel any outstanding work, if possible. Any future calls to schedule should ignore any

http://git-wip-us.apache.org/repos/asf/beam/blob/8101103b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutorServices.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutorServices.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutorServices.java
index 53087bf..9aa71f7 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutorServices.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformExecutorServices.java
@@ -69,7 +69,7 @@ final class TransformExecutorServices {
     }
 
     @Override
-    public void schedule(TransformExecutor<?> work) {
+    public void schedule(TransformExecutor work) {
       if (active.get()) {
         try {
           executor.submit(work);
@@ -92,7 +92,7 @@ final class TransformExecutorServices {
     }
 
     @Override
-    public void complete(TransformExecutor<?> completed) {
+    public void complete(TransformExecutor completed) {
     }
 
     @Override
@@ -112,8 +112,8 @@ final class TransformExecutorServices {
   private static class SerialTransformExecutor implements TransformExecutorService {
     private final ExecutorService executor;
 
-    private AtomicReference<TransformExecutor<?>> currentlyEvaluating;
-    private final Queue<TransformExecutor<?>> workQueue;
+    private AtomicReference<TransformExecutor> currentlyEvaluating;
+    private final Queue<TransformExecutor> workQueue;
     private boolean active = true;
 
     private SerialTransformExecutor(ExecutorService executor) {
@@ -127,13 +127,13 @@ final class TransformExecutorServices {
      * evaluated and scheduling it immediately otherwise.
      */
     @Override
-    public void schedule(TransformExecutor<?> work) {
+    public void schedule(TransformExecutor work) {
       workQueue.offer(work);
       updateCurrentlyEvaluating();
     }
 
     @Override
-    public void complete(TransformExecutor<?> completed) {
+    public void complete(TransformExecutor completed) {
       if (!currentlyEvaluating.compareAndSet(completed, null)) {
         throw new IllegalStateException(
             "Finished work "
@@ -156,7 +156,7 @@ final class TransformExecutorServices {
       if (currentlyEvaluating.get() == null) {
         // Only synchronize if we need to update what's currently evaluating
         synchronized (this) {
-          TransformExecutor<?> newWork = workQueue.poll();
+          TransformExecutor newWork = workQueue.poll();
           if (active && newWork != null) {
             if (currentlyEvaluating.compareAndSet(null, newWork)) {
               executor.submit(newWork);

http://git-wip-us.apache.org/repos/asf/beam/blob/8101103b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerApiSurfaceTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerApiSurfaceTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerApiSurfaceTest.java
index 631349f..f116709 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerApiSurfaceTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerApiSurfaceTest.java
@@ -42,7 +42,10 @@ public class DirectRunnerApiSurfaceTest {
     // The DirectRunner can expose the Core SDK, anything exposed by the Core SDK, and itself
     @SuppressWarnings("unchecked")
     final Set<String> allowed =
-        ImmutableSet.of("org.apache.beam.sdk", "org.apache.beam.runners.direct", "org.joda.time");
+        ImmutableSet.of(
+            "org.apache.beam.sdk",
+            "org.apache.beam.runners.direct",
+            "org.joda.time");
 
     final Package thisPackage = getClass().getPackage();
     final ClassLoader thisClassLoader = getClass().getClassLoader();

http://git-wip-us.apache.org/repos/asf/beam/blob/8101103b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectTransformExecutorTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectTransformExecutorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectTransformExecutorTest.java
new file mode 100644
index 0000000..aa83e0b
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectTransformExecutorTest.java
@@ -0,0 +1,565 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.direct;
+
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.isA;
+import static org.hamcrest.Matchers.nullValue;
+import static org.junit.Assert.assertThat;
+import static org.mockito.Mockito.when;
+
+import com.google.common.base.Optional;
+import com.google.common.collect.Iterables;
+import com.google.common.util.concurrent.MoreExecutors;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.EnumSet;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import org.apache.beam.runners.direct.CommittedResult.OutputType;
+import org.apache.beam.sdk.runners.AppliedPTransform;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.WithKeys;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.hamcrest.Matchers;
+import org.joda.time.Instant;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+/** Tests for {@link DirectTransformExecutor}. */
+@RunWith(JUnit4.class)
+public class DirectTransformExecutorTest {
+  @Rule public ExpectedException thrown = ExpectedException.none();
+  private PCollection<String> created;
+
+  private AppliedPTransform<?, ?, ?> createdProducer;
+  private AppliedPTransform<?, ?, ?> downstreamProducer;
+
+  private CountDownLatch evaluatorCompleted;
+
+  private RegisteringCompletionCallback completionCallback;
+  private TransformExecutorService transformEvaluationState;
+  private BundleFactory bundleFactory;
+  /** Runs executors off the test thread; shut down after every test so no threads leak. */
+  private ExecutorService executorService;
+  @Mock private DirectMetrics metrics;
+  @Mock private EvaluationContext evaluationContext;
+  @Mock private TransformEvaluatorRegistry registry;
+
+  @Rule
+  public TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false);
+
+  @Before
+  public void setup() {
+    MockitoAnnotations.initMocks(this);
+
+    bundleFactory = ImmutableListBundleFactory.create();
+
+    transformEvaluationState =
+        TransformExecutorServices.parallel(MoreExecutors.newDirectExecutorService());
+    executorService = Executors.newSingleThreadExecutor();
+
+    evaluatorCompleted = new CountDownLatch(1);
+    completionCallback = new RegisteringCompletionCallback(evaluatorCompleted);
+
+    created = p.apply(Create.of("foo", "spam", "third"));
+    PCollection<KV<Integer, String>> downstream = created.apply(WithKeys.<Integer, String>of(3));
+
+    DirectGraphs.performDirectOverrides(p);
+    DirectGraph graph = DirectGraphs.getGraph(p);
+    createdProducer = graph.getProducer(created);
+    downstreamProducer = graph.getProducer(downstream);
+
+    when(evaluationContext.getMetrics()).thenReturn(metrics);
+  }
+
+  @After
+  public void teardown() {
+    executorService.shutdownNow();
+  }
+
+  /**
+   * Blocks until the completion callback has been invoked, failing the test instead of hanging
+   * forever if the executor never calls back.
+   */
+  private void awaitCompletion() throws InterruptedException {
+    assertThat(
+        "Timed out waiting for the completion callback",
+        evaluatorCompleted.await(10L, TimeUnit.SECONDS),
+        is(true));
+  }
+
+  @Test
+  public void callWithNullInputBundleFinishesBundleAndCompletes() throws Exception {
+    final TransformResult<Object> result = StepTransformResult.withoutHold(createdProducer).build();
+    final AtomicBoolean finishCalled = new AtomicBoolean(false);
+    TransformEvaluator<Object> evaluator =
+        new TransformEvaluator<Object>() {
+          @Override
+          public void processElement(WindowedValue<Object> element) throws Exception {
+            throw new IllegalArgumentException("Shouldn't be called");
+          }
+
+          @Override
+          public TransformResult<Object> finishBundle() throws Exception {
+            finishCalled.set(true);
+            return result;
+          }
+        };
+
+    when(registry.forApplication(createdProducer, null)).thenReturn(evaluator);
+
+    DirectTransformExecutor<Object> executor =
+        new DirectTransformExecutor<>(
+            evaluationContext,
+            registry,
+            Collections.<ModelEnforcementFactory>emptyList(),
+            null,
+            createdProducer,
+            completionCallback,
+            transformEvaluationState);
+    executor.run();
+
+    assertThat(finishCalled.get(), is(true));
+    assertThat(completionCallback.handledResult, Matchers.<TransformResult<?>>equalTo(result));
+    assertThat(completionCallback.handledException, is(nullValue()));
+  }
+
+  @Test
+  public void nullTransformEvaluatorTerminates() throws Exception {
+    when(registry.forApplication(createdProducer, null)).thenReturn(null);
+
+    DirectTransformExecutor<Object> executor =
+        new DirectTransformExecutor<>(
+            evaluationContext,
+            registry,
+            Collections.<ModelEnforcementFactory>emptyList(),
+            null,
+            createdProducer,
+            completionCallback,
+            transformEvaluationState);
+    executor.run();
+
+    assertThat(completionCallback.handledResult, is(nullValue()));
+    assertThat(completionCallback.handledEmpty, equalTo(true));
+    assertThat(completionCallback.handledException, is(nullValue()));
+  }
+
+  @Test
+  public void inputBundleProcessesEachElementFinishesAndCompletes() throws Exception {
+    final TransformResult<String> result =
+        StepTransformResult.<String>withoutHold(downstreamProducer).build();
+    final Collection<WindowedValue<String>> elementsProcessed = new ArrayList<>();
+    TransformEvaluator<String> evaluator =
+        new TransformEvaluator<String>() {
+          @Override
+          public void processElement(WindowedValue<String> element) throws Exception {
+            elementsProcessed.add(element);
+          }
+
+          @Override
+          public TransformResult<String> finishBundle() throws Exception {
+            return result;
+          }
+        };
+
+    WindowedValue<String> foo = WindowedValue.valueInGlobalWindow("foo");
+    WindowedValue<String> spam = WindowedValue.valueInGlobalWindow("spam");
+    WindowedValue<String> third = WindowedValue.valueInGlobalWindow("third");
+    CommittedBundle<String> inputBundle =
+        bundleFactory.createBundle(created).add(foo).add(spam).add(third).commit(Instant.now());
+    when(registry.<String>forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);
+
+    DirectTransformExecutor<String> executor =
+        new DirectTransformExecutor<>(
+            evaluationContext,
+            registry,
+            Collections.<ModelEnforcementFactory>emptyList(),
+            inputBundle,
+            downstreamProducer,
+            completionCallback,
+            transformEvaluationState);
+
+    executorService.submit(executor);
+
+    awaitCompletion();
+
+    assertThat(elementsProcessed, containsInAnyOrder(spam, third, foo));
+    assertThat(completionCallback.handledResult, Matchers.<TransformResult<?>>equalTo(result));
+    assertThat(completionCallback.handledException, is(nullValue()));
+  }
+
+  @Test
+  public void processElementThrowsExceptionCallsback() throws Exception {
+    final TransformResult<String> result =
+        StepTransformResult.<String>withoutHold(downstreamProducer).build();
+    final Exception exception = new Exception();
+    TransformEvaluator<String> evaluator =
+        new TransformEvaluator<String>() {
+          @Override
+          public void processElement(WindowedValue<String> element) throws Exception {
+            throw exception;
+          }
+
+          @Override
+          public TransformResult<String> finishBundle() throws Exception {
+            return result;
+          }
+        };
+
+    WindowedValue<String> foo = WindowedValue.valueInGlobalWindow("foo");
+    CommittedBundle<String> inputBundle =
+        bundleFactory.createBundle(created).add(foo).commit(Instant.now());
+    when(registry.<String>forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);
+
+    DirectTransformExecutor<String> executor =
+        new DirectTransformExecutor<>(
+            evaluationContext,
+            registry,
+            Collections.<ModelEnforcementFactory>emptyList(),
+            inputBundle,
+            downstreamProducer,
+            completionCallback,
+            transformEvaluationState);
+    executorService.submit(executor);
+
+    awaitCompletion();
+
+    assertThat(completionCallback.handledResult, is(nullValue()));
+    assertThat(completionCallback.handledException, Matchers.<Throwable>equalTo(exception));
+  }
+
+  @Test
+  public void finishBundleThrowsExceptionCallsback() throws Exception {
+    final Exception exception = new Exception();
+    TransformEvaluator<String> evaluator =
+        new TransformEvaluator<String>() {
+          @Override
+          public void processElement(WindowedValue<String> element) throws Exception {}
+
+          @Override
+          public TransformResult<String> finishBundle() throws Exception {
+            throw exception;
+          }
+        };
+
+    CommittedBundle<String> inputBundle = bundleFactory.createBundle(created).commit(Instant.now());
+    when(registry.<String>forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);
+
+    DirectTransformExecutor<String> executor =
+        new DirectTransformExecutor<>(
+            evaluationContext,
+            registry,
+            Collections.<ModelEnforcementFactory>emptyList(),
+            inputBundle,
+            downstreamProducer,
+            completionCallback,
+            transformEvaluationState);
+    executorService.submit(executor);
+
+    awaitCompletion();
+
+    assertThat(completionCallback.handledResult, is(nullValue()));
+    assertThat(completionCallback.handledException, Matchers.<Throwable>equalTo(exception));
+  }
+
+  @Test
+  public void callWithEnforcementAppliesEnforcement() throws Exception {
+    final TransformResult<Object> result =
+        StepTransformResult.withoutHold(downstreamProducer).build();
+
+    TransformEvaluator<Object> evaluator =
+        new TransformEvaluator<Object>() {
+          @Override
+          public void processElement(WindowedValue<Object> element) throws Exception {}
+
+          @Override
+          public TransformResult<Object> finishBundle() throws Exception {
+            return result;
+          }
+        };
+
+    WindowedValue<String> fooElem = WindowedValue.valueInGlobalWindow("foo");
+    WindowedValue<String> barElem = WindowedValue.valueInGlobalWindow("bar");
+    CommittedBundle<String> inputBundle =
+        bundleFactory.createBundle(created).add(fooElem).add(barElem).commit(Instant.now());
+    when(registry.forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);
+
+    TestEnforcementFactory enforcement = new TestEnforcementFactory();
+    DirectTransformExecutor<String> executor =
+        new DirectTransformExecutor<>(
+            evaluationContext,
+            registry,
+            Collections.<ModelEnforcementFactory>singleton(enforcement),
+            inputBundle,
+            downstreamProducer,
+            completionCallback,
+            transformEvaluationState);
+
+    executor.run();
+    TestEnforcement<?> testEnforcement = enforcement.instance;
+    assertThat(
+        testEnforcement.beforeElements,
+        Matchers.<WindowedValue<?>>containsInAnyOrder(barElem, fooElem));
+    assertThat(
+        testEnforcement.afterElements,
+        Matchers.<WindowedValue<?>>containsInAnyOrder(barElem, fooElem));
+    assertThat(testEnforcement.finishedBundles, Matchers.<TransformResult<?>>contains(result));
+  }
+
+  @Test
+  public void callWithEnforcementThrowsOnFinishPropagates() throws Exception {
+    final TransformResult<Object> result =
+        StepTransformResult.withoutHold(downstreamProducer).build();
+
+    TransformEvaluator<Object> evaluator =
+        new TransformEvaluator<Object>() {
+          @Override
+          public void processElement(WindowedValue<Object> element) throws Exception {}
+
+          @Override
+          public TransformResult<Object> finishBundle() throws Exception {
+            return result;
+          }
+        };
+
+    WindowedValue<String> fooBytes = WindowedValue.valueInGlobalWindow("foo");
+    CommittedBundle<String> inputBundle =
+        bundleFactory.createBundle(created).add(fooBytes).commit(Instant.now());
+    when(registry.forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);
+
+    DirectTransformExecutor<String> executor =
+        new DirectTransformExecutor<>(
+            evaluationContext,
+            registry,
+            Collections.<ModelEnforcementFactory>singleton(
+                new ThrowingEnforcementFactory(ThrowingEnforcementFactory.When.AFTER_BUNDLE)),
+            inputBundle,
+            downstreamProducer,
+            completionCallback,
+            transformEvaluationState);
+
+    Future<?> task = executorService.submit(executor);
+
+    thrown.expectCause(isA(RuntimeException.class));
+    thrown.expectMessage("afterFinish");
+    task.get();
+  }
+
+  @Test
+  public void callWithEnforcementThrowsOnElementPropagates() throws Exception {
+    final TransformResult<Object> result =
+        StepTransformResult.withoutHold(downstreamProducer).build();
+
+    TransformEvaluator<Object> evaluator =
+        new TransformEvaluator<Object>() {
+          @Override
+          public void processElement(WindowedValue<Object> element) throws Exception {}
+
+          @Override
+          public TransformResult<Object> finishBundle() throws Exception {
+            return result;
+          }
+        };
+
+    WindowedValue<String> fooBytes = WindowedValue.valueInGlobalWindow("foo");
+    CommittedBundle<String> inputBundle =
+        bundleFactory.createBundle(created).add(fooBytes).commit(Instant.now());
+    when(registry.forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);
+
+    DirectTransformExecutor<String> executor =
+        new DirectTransformExecutor<>(
+            evaluationContext,
+            registry,
+            Collections.<ModelEnforcementFactory>singleton(
+                new ThrowingEnforcementFactory(ThrowingEnforcementFactory.When.AFTER_ELEMENT)),
+            inputBundle,
+            downstreamProducer,
+            completionCallback,
+            transformEvaluationState);
+
+    Future<?> task = executorService.submit(executor);
+
+    thrown.expectCause(isA(RuntimeException.class));
+    thrown.expectMessage("afterElement");
+    task.get();
+  }
+
+  /**
+   * Records what the executor reported via {@code handleResult}, {@code handleEmpty} or
+   * {@code handleException}, counting down a latch after each so tests can wait for completion.
+   */
+  private static class RegisteringCompletionCallback implements CompletionCallback {
+    private TransformResult<?> handledResult = null;
+    private boolean handledEmpty = false;
+    private Exception handledException = null;
+    private final CountDownLatch onMethod;
+
+    private RegisteringCompletionCallback(CountDownLatch onMethod) {
+      this.onMethod = onMethod;
+    }
+
+    @Override
+    public CommittedResult handleResult(CommittedBundle<?> inputBundle, TransformResult<?> result) {
+      handledResult = result;
+      onMethod.countDown();
+      @SuppressWarnings("rawtypes")
+      Iterable unprocessedElements =
+          result.getUnprocessedElements() == null
+              ? Collections.emptyList()
+              : result.getUnprocessedElements();
+
+      Optional<? extends CommittedBundle<?>> unprocessedBundle;
+      if (inputBundle == null || Iterables.isEmpty(unprocessedElements)) {
+        unprocessedBundle = Optional.absent();
+      } else {
+        unprocessedBundle =
+            Optional.<CommittedBundle<?>>of(inputBundle.withElements(unprocessedElements));
+      }
+      return CommittedResult.create(
+          result,
+          unprocessedBundle,
+          Collections.<CommittedBundle<?>>emptyList(),
+          EnumSet.noneOf(OutputType.class));
+    }
+
+    @Override
+    public void handleEmpty(AppliedPTransform<?, ?, ?> transform) {
+      handledEmpty = true;
+      onMethod.countDown();
+    }
+
+    @Override
+    public void handleException(CommittedBundle<?> inputBundle, Exception e) {
+      handledException = e;
+      onMethod.countDown();
+    }
+
+    @Override
+    public void handleError(Error err) {
+      throw err;
+    }
+  }
+
+  /** Creates {@link TestEnforcement}s, remembering the most recently created one. */
+  private static class TestEnforcementFactory implements ModelEnforcementFactory {
+    private TestEnforcement<?> instance;
+
+    @Override
+    public <T> TestEnforcement<T> forBundle(
+        CommittedBundle<T> input, AppliedPTransform<?, ?, ?> consumer) {
+      TestEnforcement<T> newEnforcement = new TestEnforcement<>();
+      instance = newEnforcement;
+      return newEnforcement;
+    }
+  }
+
+  /** Records the elements and bundle results passed to its {@link ModelEnforcement} hooks. */
+  private static class TestEnforcement<T> implements ModelEnforcement<T> {
+    private final List<WindowedValue<T>> beforeElements = new ArrayList<>();
+    private final List<WindowedValue<T>> afterElements = new ArrayList<>();
+    private final List<TransformResult<?>> finishedBundles = new ArrayList<>();
+
+    @Override
+    public void beforeElement(WindowedValue<T> element) {
+      beforeElements.add(element);
+    }
+
+    @Override
+    public void afterElement(WindowedValue<T> element) {
+      afterElements.add(element);
+    }
+
+    @Override
+    public void afterFinish(
+        CommittedBundle<T> input,
+        TransformResult<T> result,
+        Iterable<? extends CommittedBundle<?>> outputs) {
+      finishedBundles.add(result);
+    }
+  }
+
+  /** Throws a {@link RuntimeException} from whichever enforcement hook {@link When} selects. */
+  private static class ThrowingEnforcementFactory implements ModelEnforcementFactory {
+    private final When when;
+
+    private ThrowingEnforcementFactory(When when) {
+      this.when = when;
+    }
+
+    enum When {
+      BEFORE_BUNDLE,
+      BEFORE_ELEMENT,
+      AFTER_ELEMENT,
+      AFTER_BUNDLE
+    }
+
+    @Override
+    public <T> ModelEnforcement<T> forBundle(
+        CommittedBundle<T> input, AppliedPTransform<?, ?, ?> consumer) {
+      if (when == When.BEFORE_BUNDLE) {
+        throw new RuntimeException("forBundle");
+      }
+      return new ThrowingEnforcement<>();
+    }
+
+    private class ThrowingEnforcement<T> implements ModelEnforcement<T> {
+      @Override
+      public void beforeElement(WindowedValue<T> element) {
+        if (when == When.BEFORE_ELEMENT) {
+          throw new RuntimeException("beforeElement");
+        }
+      }
+
+      @Override
+      public void afterElement(WindowedValue<T> element) {
+        if (when == When.AFTER_ELEMENT) {
+          throw new RuntimeException("afterElement");
+        }
+      }
+
+      @Override
+      public void afterFinish(
+          CommittedBundle<T> input,
+          TransformResult<T> result,
+          Iterable<? extends CommittedBundle<?>> outputs) {
+        if (when == When.AFTER_BUNDLE) {
+          throw new RuntimeException("afterFinish");
+        }
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/8101103b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TransformExecutorServicesTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TransformExecutorServicesTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TransformExecutorServicesTest.java
index 77652b2..5d1c994 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TransformExecutorServicesTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TransformExecutorServicesTest.java
@@ -46,10 +46,10 @@ public class TransformExecutorServicesTest {
 
   @Test
   public void parallelScheduleMultipleSchedulesBothImmediately() {
-    @SuppressWarnings("unchecked")
-    TransformExecutor<Object> first = mock(TransformExecutor.class);
-    @SuppressWarnings("unchecked")
-    TransformExecutor<Object> second = mock(TransformExecutor.class);
+    @SuppressWarnings("unchecked") DirectTransformExecutor<Object>
+        first = mock(DirectTransformExecutor.class);
+    @SuppressWarnings("unchecked") DirectTransformExecutor<Object>
+        second = mock(DirectTransformExecutor.class);
 
     TransformExecutorService parallel =
         TransformExecutorServices.parallel(executorService);
@@ -65,8 +65,8 @@ public class TransformExecutorServicesTest {
 
   @Test
   public void parallelRejectedStillActiveThrows() {
-    @SuppressWarnings("unchecked")
-    TransformExecutor<Object> first = mock(TransformExecutor.class);
+    @SuppressWarnings("unchecked") DirectTransformExecutor<Object>
+        first = mock(DirectTransformExecutor.class);
 
     TransformExecutorService parallel =
         TransformExecutorServices.parallel(executorService);
@@ -78,8 +78,8 @@ public class TransformExecutorServicesTest {
 
   @Test
   public void parallelRejectedShutdownSucceeds() {
-    @SuppressWarnings("unchecked")
-    TransformExecutor<Object> first = mock(TransformExecutor.class);
+    @SuppressWarnings("unchecked") DirectTransformExecutor<Object>
+        first = mock(DirectTransformExecutor.class);
 
     TransformExecutorService parallel =
         TransformExecutorServices.parallel(executorService);
@@ -90,10 +90,10 @@ public class TransformExecutorServicesTest {
 
   @Test
   public void serialScheduleTwoWaitsForFirstToComplete() {
-    @SuppressWarnings("unchecked")
-    TransformExecutor<Object> first = mock(TransformExecutor.class);
-    @SuppressWarnings("unchecked")
-    TransformExecutor<Object> second = mock(TransformExecutor.class);
+    @SuppressWarnings("unchecked") DirectTransformExecutor<Object>
+        first = mock(DirectTransformExecutor.class);
+    @SuppressWarnings("unchecked") DirectTransformExecutor<Object>
+        second = mock(DirectTransformExecutor.class);
 
     TransformExecutorService serial = TransformExecutorServices.serial(executorService);
     serial.schedule(first);
@@ -110,10 +110,10 @@ public class TransformExecutorServicesTest {
 
   @Test
   public void serialCompleteNotExecutingTaskThrows() {
-    @SuppressWarnings("unchecked")
-    TransformExecutor<Object> first = mock(TransformExecutor.class);
-    @SuppressWarnings("unchecked")
-    TransformExecutor<Object> second = mock(TransformExecutor.class);
+    @SuppressWarnings("unchecked") DirectTransformExecutor<Object>
+        first = mock(DirectTransformExecutor.class);
+    @SuppressWarnings("unchecked") DirectTransformExecutor<Object>
+        second = mock(DirectTransformExecutor.class);
 
     TransformExecutorService serial = TransformExecutorServices.serial(executorService);
     serial.schedule(first);
@@ -129,10 +129,10 @@ public class TransformExecutorServicesTest {
    */
   @Test
   public void serialShutdownCompleteActive() {
-    @SuppressWarnings("unchecked")
-    TransformExecutor<Object> first = mock(TransformExecutor.class);
-    @SuppressWarnings("unchecked")
-    TransformExecutor<Object> second = mock(TransformExecutor.class);
+    @SuppressWarnings("unchecked") DirectTransformExecutor<Object>
+        first = mock(DirectTransformExecutor.class);
+    @SuppressWarnings("unchecked") DirectTransformExecutor<Object>
+        second = mock(DirectTransformExecutor.class);
 
     TransformExecutorService serial = TransformExecutorServices.serial(executorService);
     serial.schedule(first);

http://git-wip-us.apache.org/repos/asf/beam/blob/8101103b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TransformExecutorTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TransformExecutorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TransformExecutorTest.java
deleted file mode 100644
index b7f5a7c..0000000
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TransformExecutorTest.java
+++ /dev/null
@@ -1,537 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.direct;
-
-import static org.hamcrest.Matchers.containsInAnyOrder;
-import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.is;
-import static org.hamcrest.Matchers.isA;
-import static org.hamcrest.Matchers.nullValue;
-import static org.junit.Assert.assertThat;
-import static org.mockito.Mockito.when;
-
-import com.google.common.base.Optional;
-import com.google.common.collect.Iterables;
-import com.google.common.util.concurrent.MoreExecutors;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.EnumSet;
-import java.util.List;
-import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
-import java.util.concurrent.atomic.AtomicBoolean;
-import org.apache.beam.runners.direct.CommittedResult.OutputType;
-import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.testing.TestPipeline;
-import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.transforms.WithKeys;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.KV;
-import org.apache.beam.sdk.values.PCollection;
-import org.hamcrest.Matchers;
-import org.joda.time.Instant;
-import org.junit.Before;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.ExpectedException;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-/** Tests for {@link TransformExecutor}. */
-@RunWith(JUnit4.class)
-public class TransformExecutorTest {
-  @Rule public ExpectedException thrown = ExpectedException.none();
-  private PCollection<String> created;
-
-  private AppliedPTransform<?, ?, ?> createdProducer;
-  private AppliedPTransform<?, ?, ?> downstreamProducer;
-
-  private CountDownLatch evaluatorCompleted;
-
-  private RegisteringCompletionCallback completionCallback;
-  private TransformExecutorService transformEvaluationState;
-  private BundleFactory bundleFactory;
-  @Mock private DirectMetrics metrics;
-  @Mock private EvaluationContext evaluationContext;
-  @Mock private TransformEvaluatorRegistry registry;
-
-  @Rule
-  public TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false);
-
-  @Before
-  public void setup() {
-    MockitoAnnotations.initMocks(this);
-
-    bundleFactory = ImmutableListBundleFactory.create();
-
-    transformEvaluationState =
-        TransformExecutorServices.parallel(MoreExecutors.newDirectExecutorService());
-
-    evaluatorCompleted = new CountDownLatch(1);
-    completionCallback = new RegisteringCompletionCallback(evaluatorCompleted);
-
-    created = p.apply(Create.of("foo", "spam", "third"));
-    PCollection<KV<Integer, String>> downstream = created.apply(WithKeys.<Integer, String>of(3));
-
-    DirectGraphs.performDirectOverrides(p);
-    DirectGraph graph = DirectGraphs.getGraph(p);
-    createdProducer = graph.getProducer(created);
-    downstreamProducer = graph.getProducer(downstream);
-
-    when(evaluationContext.getMetrics()).thenReturn(metrics);
-  }
-
-  @Test
-  public void callWithNullInputBundleFinishesBundleAndCompletes() throws Exception {
-    final TransformResult<Object> result = StepTransformResult.withoutHold(createdProducer).build();
-    final AtomicBoolean finishCalled = new AtomicBoolean(false);
-    TransformEvaluator<Object> evaluator =
-        new TransformEvaluator<Object>() {
-          @Override
-          public void processElement(WindowedValue<Object> element) throws Exception {
-            throw new IllegalArgumentException("Shouldn't be called");
-          }
-
-          @Override
-          public TransformResult<Object> finishBundle() throws Exception {
-            finishCalled.set(true);
-            return result;
-          }
-        };
-
-    when(registry.forApplication(createdProducer, null)).thenReturn(evaluator);
-
-    TransformExecutor<Object> executor =
-        TransformExecutor.create(
-            evaluationContext,
-            registry,
-            Collections.<ModelEnforcementFactory>emptyList(),
-            null,
-            createdProducer,
-            completionCallback,
-            transformEvaluationState);
-    executor.run();
-
-    assertThat(finishCalled.get(), is(true));
-    assertThat(completionCallback.handledResult, Matchers.<TransformResult<?>>equalTo(result));
-    assertThat(completionCallback.handledException, is(nullValue()));
-  }
-
-  @Test
-  public void nullTransformEvaluatorTerminates() throws Exception {
-    when(registry.forApplication(createdProducer, null)).thenReturn(null);
-
-    TransformExecutor<Object> executor =
-        TransformExecutor.create(
-            evaluationContext,
-            registry,
-            Collections.<ModelEnforcementFactory>emptyList(),
-            null,
-            createdProducer,
-            completionCallback,
-            transformEvaluationState);
-    executor.run();
-
-    assertThat(completionCallback.handledResult, is(nullValue()));
-    assertThat(completionCallback.handledEmpty, equalTo(true));
-    assertThat(completionCallback.handledException, is(nullValue()));
-  }
-
-  @Test
-  public void inputBundleProcessesEachElementFinishesAndCompletes() throws Exception {
-    final TransformResult<String> result =
-        StepTransformResult.<String>withoutHold(downstreamProducer).build();
-    final Collection<WindowedValue<String>> elementsProcessed = new ArrayList<>();
-    TransformEvaluator<String> evaluator =
-        new TransformEvaluator<String>() {
-          @Override
-          public void processElement(WindowedValue<String> element) throws Exception {
-            elementsProcessed.add(element);
-            return;
-          }
-
-          @Override
-          public TransformResult<String> finishBundle() throws Exception {
-            return result;
-          }
-        };
-
-    WindowedValue<String> foo = WindowedValue.valueInGlobalWindow("foo");
-    WindowedValue<String> spam = WindowedValue.valueInGlobalWindow("spam");
-    WindowedValue<String> third = WindowedValue.valueInGlobalWindow("third");
-    CommittedBundle<String> inputBundle =
-        bundleFactory.createBundle(created).add(foo).add(spam).add(third).commit(Instant.now());
-    when(registry.<String>forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);
-
-    TransformExecutor<String> executor =
-        TransformExecutor.create(
-            evaluationContext,
-            registry,
-            Collections.<ModelEnforcementFactory>emptyList(),
-            inputBundle,
-            downstreamProducer,
-            completionCallback,
-            transformEvaluationState);
-
-    Executors.newSingleThreadExecutor().submit(executor);
-
-    evaluatorCompleted.await();
-
-    assertThat(elementsProcessed, containsInAnyOrder(spam, third, foo));
-    assertThat(completionCallback.handledResult, Matchers.<TransformResult<?>>equalTo(result));
-    assertThat(completionCallback.handledException, is(nullValue()));
-  }
-
-  @Test
-  public void processElementThrowsExceptionCallsback() throws Exception {
-    final TransformResult<String> result =
-        StepTransformResult.<String>withoutHold(downstreamProducer).build();
-    final Exception exception = new Exception();
-    TransformEvaluator<String> evaluator =
-        new TransformEvaluator<String>() {
-          @Override
-          public void processElement(WindowedValue<String> element) throws Exception {
-            throw exception;
-          }
-
-          @Override
-          public TransformResult<String> finishBundle() throws Exception {
-            return result;
-          }
-        };
-
-    WindowedValue<String> foo = WindowedValue.valueInGlobalWindow("foo");
-    CommittedBundle<String> inputBundle =
-        bundleFactory.createBundle(created).add(foo).commit(Instant.now());
-    when(registry.<String>forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);
-
-    TransformExecutor<String> executor =
-        TransformExecutor.create(
-            evaluationContext,
-            registry,
-            Collections.<ModelEnforcementFactory>emptyList(),
-            inputBundle,
-            downstreamProducer,
-            completionCallback,
-            transformEvaluationState);
-    Executors.newSingleThreadExecutor().submit(executor);
-
-    evaluatorCompleted.await();
-
-    assertThat(completionCallback.handledResult, is(nullValue()));
-    assertThat(completionCallback.handledException, Matchers.<Throwable>equalTo(exception));
-  }
-
-  @Test
-  public void finishBundleThrowsExceptionCallsback() throws Exception {
-    final Exception exception = new Exception();
-    TransformEvaluator<String> evaluator =
-        new TransformEvaluator<String>() {
-          @Override
-          public void processElement(WindowedValue<String> element) throws Exception {}
-
-          @Override
-          public TransformResult<String> finishBundle() throws Exception {
-            throw exception;
-          }
-        };
-
-    CommittedBundle<String> inputBundle = bundleFactory.createBundle(created).commit(Instant.now());
-    when(registry.<String>forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);
-
-    TransformExecutor<String> executor =
-        TransformExecutor.create(
-            evaluationContext,
-            registry,
-            Collections.<ModelEnforcementFactory>emptyList(),
-            inputBundle,
-            downstreamProducer,
-            completionCallback,
-            transformEvaluationState);
-    Executors.newSingleThreadExecutor().submit(executor);
-
-    evaluatorCompleted.await();
-
-    assertThat(completionCallback.handledResult, is(nullValue()));
-    assertThat(completionCallback.handledException, Matchers.<Throwable>equalTo(exception));
-  }
-
-  @Test
-  public void callWithEnforcementAppliesEnforcement() throws Exception {
-    final TransformResult<Object> result =
-        StepTransformResult.withoutHold(downstreamProducer).build();
-
-    TransformEvaluator<Object> evaluator =
-        new TransformEvaluator<Object>() {
-          @Override
-          public void processElement(WindowedValue<Object> element) throws Exception {}
-
-          @Override
-          public TransformResult<Object> finishBundle() throws Exception {
-            return result;
-          }
-        };
-
-    WindowedValue<String> fooElem = WindowedValue.valueInGlobalWindow("foo");
-    WindowedValue<String> barElem = WindowedValue.valueInGlobalWindow("bar");
-    CommittedBundle<String> inputBundle =
-        bundleFactory.createBundle(created).add(fooElem).add(barElem).commit(Instant.now());
-    when(registry.forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);
-
-    TestEnforcementFactory enforcement = new TestEnforcementFactory();
-    TransformExecutor<String> executor =
-        TransformExecutor.create(
-            evaluationContext,
-            registry,
-            Collections.<ModelEnforcementFactory>singleton(enforcement),
-            inputBundle,
-            downstreamProducer,
-            completionCallback,
-            transformEvaluationState);
-
-    executor.run();
-    TestEnforcement<?> testEnforcement = enforcement.instance;
-    assertThat(
-        testEnforcement.beforeElements,
-        Matchers.<WindowedValue<?>>containsInAnyOrder(barElem, fooElem));
-    assertThat(
-        testEnforcement.afterElements,
-        Matchers.<WindowedValue<?>>containsInAnyOrder(barElem, fooElem));
-    assertThat(testEnforcement.finishedBundles, Matchers.<TransformResult<?>>contains(result));
-  }
-
-  @Test
-  public void callWithEnforcementThrowsOnFinishPropagates() throws Exception {
-    final TransformResult<Object> result =
-        StepTransformResult.withoutHold(createdProducer).build();
-
-    TransformEvaluator<Object> evaluator =
-        new TransformEvaluator<Object>() {
-          @Override
-          public void processElement(WindowedValue<Object> element) throws Exception {}
-
-          @Override
-          public TransformResult<Object> finishBundle() throws Exception {
-            return result;
-          }
-        };
-
-    WindowedValue<String> fooBytes = WindowedValue.valueInGlobalWindow("foo");
-    CommittedBundle<String> inputBundle =
-        bundleFactory.createBundle(created).add(fooBytes).commit(Instant.now());
-    when(registry.forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);
-
-    TransformExecutor<String> executor =
-        TransformExecutor.create(
-            evaluationContext,
-            registry,
-            Collections.<ModelEnforcementFactory>singleton(
-                new ThrowingEnforcementFactory(ThrowingEnforcementFactory.When.AFTER_BUNDLE)),
-            inputBundle,
-            downstreamProducer,
-            completionCallback,
-            transformEvaluationState);
-
-    Future<?> task = Executors.newSingleThreadExecutor().submit(executor);
-
-    thrown.expectCause(isA(RuntimeException.class));
-    thrown.expectMessage("afterFinish");
-    task.get();
-  }
-
-  @Test
-  public void callWithEnforcementThrowsOnElementPropagates() throws Exception {
-    final TransformResult<Object> result =
-        StepTransformResult.withoutHold(createdProducer).build();
-
-    TransformEvaluator<Object> evaluator =
-        new TransformEvaluator<Object>() {
-          @Override
-          public void processElement(WindowedValue<Object> element) throws Exception {}
-
-          @Override
-          public TransformResult<Object> finishBundle() throws Exception {
-            return result;
-          }
-        };
-
-    WindowedValue<String> fooBytes = WindowedValue.valueInGlobalWindow("foo");
-    CommittedBundle<String> inputBundle =
-        bundleFactory.createBundle(created).add(fooBytes).commit(Instant.now());
-    when(registry.forApplication(downstreamProducer, inputBundle)).thenReturn(evaluator);
-
-    TransformExecutor<String> executor =
-        TransformExecutor.create(
-            evaluationContext,
-            registry,
-            Collections.<ModelEnforcementFactory>singleton(
-                new ThrowingEnforcementFactory(ThrowingEnforcementFactory.When.AFTER_ELEMENT)),
-            inputBundle,
-            downstreamProducer,
-            completionCallback,
-            transformEvaluationState);
-
-    Future<?> task = Executors.newSingleThreadExecutor().submit(executor);
-
-    thrown.expectCause(isA(RuntimeException.class));
-    thrown.expectMessage("afterElement");
-    task.get();
-  }
-
-  private static class RegisteringCompletionCallback implements CompletionCallback {
-    private TransformResult<?> handledResult = null;
-    private boolean handledEmpty = false;
-    private Exception handledException = null;
-    private final CountDownLatch onMethod;
-
-    private RegisteringCompletionCallback(CountDownLatch onMethod) {
-      this.onMethod = onMethod;
-    }
-
-    @Override
-    public CommittedResult handleResult(CommittedBundle<?> inputBundle, TransformResult<?> result) {
-      handledResult = result;
-      onMethod.countDown();
-      @SuppressWarnings("rawtypes")
-      Iterable unprocessedElements =
-          result.getUnprocessedElements() == null
-              ? Collections.emptyList()
-              : result.getUnprocessedElements();
-
-      Optional<? extends CommittedBundle<?>> unprocessedBundle;
-      if (inputBundle == null || Iterables.isEmpty(unprocessedElements)) {
-        unprocessedBundle = Optional.absent();
-      } else {
-        unprocessedBundle =
-            Optional.<CommittedBundle<?>>of(inputBundle.withElements(unprocessedElements));
-      }
-      return CommittedResult.create(
-          result,
-          unprocessedBundle,
-          Collections.<CommittedBundle<?>>emptyList(),
-          EnumSet.noneOf(OutputType.class));
-    }
-
-    @Override
-    public void handleEmpty(AppliedPTransform<?, ?, ?> transform) {
-      handledEmpty = true;
-      onMethod.countDown();
-    }
-
-    @Override
-    public void handleException(CommittedBundle<?> inputBundle, Exception e) {
-      handledException = e;
-      onMethod.countDown();
-    }
-
-    @Override
-    public void handleError(Error err) {
-      throw err;
-    }
-  }
-
-  private static class TestEnforcementFactory implements ModelEnforcementFactory {
-    private TestEnforcement<?> instance;
-
-    @Override
-    public <T> TestEnforcement<T> forBundle(
-        CommittedBundle<T> input, AppliedPTransform<?, ?, ?> consumer) {
-      TestEnforcement<T> newEnforcement = new TestEnforcement<>();
-      instance = newEnforcement;
-      return newEnforcement;
-    }
-  }
-
-  private static class TestEnforcement<T> implements ModelEnforcement<T> {
-    private final List<WindowedValue<T>> beforeElements = new ArrayList<>();
-    private final List<WindowedValue<T>> afterElements = new ArrayList<>();
-    private final List<TransformResult<?>> finishedBundles = new ArrayList<>();
-
-    @Override
-    public void beforeElement(WindowedValue<T> element) {
-      beforeElements.add(element);
-    }
-
-    @Override
-    public void afterElement(WindowedValue<T> element) {
-      afterElements.add(element);
-    }
-
-    @Override
-    public void afterFinish(
-        CommittedBundle<T> input,
-        TransformResult<T> result,
-        Iterable<? extends CommittedBundle<?>> outputs) {
-      finishedBundles.add(result);
-    }
-  }
-
-  private static class ThrowingEnforcementFactory implements ModelEnforcementFactory {
-    private final When when;
-
-    private ThrowingEnforcementFactory(When when) {
-      this.when = when;
-    }
-
-    enum When {
-      BEFORE_BUNDLE,
-      BEFORE_ELEMENT,
-      AFTER_ELEMENT,
-      AFTER_BUNDLE
-    }
-
-    @Override
-    public <T> ModelEnforcement<T> forBundle(
-        CommittedBundle<T> input, AppliedPTransform<?, ?, ?> consumer) {
-      if (when == When.BEFORE_BUNDLE) {
-        throw new RuntimeException("forBundle");
-      }
-      return new ThrowingEnforcement<>();
-    }
-
-    private class ThrowingEnforcement<T> implements ModelEnforcement<T> {
-      @Override
-      public void beforeElement(WindowedValue<T> element) {
-        if (when == When.BEFORE_ELEMENT) {
-          throw new RuntimeException("beforeElement");
-        }
-      }
-
-      @Override
-      public void afterElement(WindowedValue<T> element) {
-        if (when == When.AFTER_ELEMENT) {
-          throw new RuntimeException("afterElement");
-        }
-      }
-
-      @Override
-      public void afterFinish(
-          CommittedBundle<T> input,
-          TransformResult<T> result,
-          Iterable<? extends CommittedBundle<?>> outputs) {
-        if (when == When.AFTER_BUNDLE) {
-          throw new RuntimeException("afterFinish");
-        }
-      }
-    }
-  }
-}