You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by dh...@apache.org on 2016/08/15 21:17:06 UTC

[1/5] incubator-beam git commit: Closes #690

Repository: incubator-beam
Updated Branches:
  refs/heads/master 0b1f66421 -> 7c680079b


Closes #690


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/7c680079
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/7c680079
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/7c680079

Branch: refs/heads/master
Commit: 7c680079b5074ff31257d7f8fff77af1dd9ea62c
Parents: 0b1f664 29cbdce
Author: Dan Halperin <dh...@google.com>
Authored: Mon Aug 15 14:16:54 2016 -0700
Committer: Dan Halperin <dh...@google.com>
Committed: Mon Aug 15 14:16:54 2016 -0700

----------------------------------------------------------------------
 .../direct/BoundedReadEvaluatorFactory.java     |   4 +
 .../beam/runners/direct/CloningThreadLocal.java |  43 --
 .../runners/direct/DoFnLifecycleManager.java    | 106 +++++
 ...ecycleManagerRemovingTransformEvaluator.java |  79 ++++
 .../runners/direct/DoFnLifecycleManagers.java   |  45 ++
 .../direct/ExecutorServiceParallelExecutor.java |   9 +-
 .../runners/direct/FlattenEvaluatorFactory.java |   3 +
 .../GroupAlsoByWindowEvaluatorFactory.java      |   6 +-
 .../direct/GroupByKeyOnlyEvaluatorFactory.java  |   4 +-
 .../direct/ParDoMultiEvaluatorFactory.java      |  55 ++-
 .../direct/ParDoSingleEvaluatorFactory.java     |  42 +-
 ...readLocalInvalidatingTransformEvaluator.java |  63 ---
 .../direct/TransformEvaluatorFactory.java       |   8 +
 .../direct/TransformEvaluatorRegistry.java      |  41 ++
 .../direct/UnboundedReadEvaluatorFactory.java   |   3 +
 .../runners/direct/ViewEvaluatorFactory.java    |   3 +
 .../runners/direct/WindowEvaluatorFactory.java  |   3 +
 .../runners/direct/CloningThreadLocalTest.java  |  92 ----
 ...leManagerRemovingTransformEvaluatorTest.java | 144 ++++++
 .../direct/DoFnLifecycleManagerTest.java        | 168 +++++++
 .../direct/DoFnLifecycleManagersTest.java       | 142 ++++++
 ...LocalInvalidatingTransformEvaluatorTest.java | 135 ------
 .../functions/FlinkDoFnFunction.java            |  12 +-
 .../functions/FlinkMultiOutputDoFnFunction.java |  31 +-
 .../streaming/FlinkAbstractParDoWrapper.java    |   2 +
 .../FlinkGroupAlsoByWindowWrapper.java          |   2 +
 runners/google-cloud-dataflow-java/pom.xml      |  10 +
 .../runners/spark/translation/DoFnFunction.java |  23 +-
 .../spark/translation/MultiDoFnFunction.java    |   1 +
 .../spark/translation/SparkProcessContext.java  |  17 +
 .../org/apache/beam/sdk/transforms/DoFn.java    |  31 +-
 .../beam/sdk/transforms/DoFnReflector.java      |  70 ++-
 .../org/apache/beam/sdk/transforms/OldDoFn.java |  25 ++
 .../org/apache/beam/sdk/transforms/ParDo.java   |  15 +-
 .../beam/sdk/transforms/DoFnReflectorTest.java  |  65 +++
 .../beam/sdk/transforms/ParDoLifecycleTest.java | 448 +++++++++++++++++++
 .../apache/beam/sdk/transforms/ParDoTest.java   |  15 +-
 37 files changed, 1553 insertions(+), 412 deletions(-)
----------------------------------------------------------------------



[5/5] incubator-beam git commit: Replace CloningThreadLocal with DoFnLifecycleManager

Posted by dh...@apache.org.
Replace CloningThreadLocal with DoFnLifecycleManager

This is a more focused interface that interacts with a DoFn before it
is available for use and after it has completed and the reference is
lost. It is required to properly support setup and teardown, as the
fields in a ThreadLocal cannot all be cleaned up without additional
tracking.

Part of BEAM-452.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/cf0bf3bf
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/cf0bf3bf
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/cf0bf3bf

Branch: refs/heads/master
Commit: cf0bf3bf9fcab2b01d69ff90d9ba3f602a8a5bd4
Parents: 12b1967
Author: Thomas Groh <tg...@google.com>
Authored: Tue Jul 19 11:03:15 2016 -0700
Committer: Dan Halperin <dh...@google.com>
Committed: Mon Aug 15 14:16:54 2016 -0700

----------------------------------------------------------------------
 .../beam/runners/direct/CloningThreadLocal.java |  43 ------
 .../runners/direct/DoFnLifecycleManager.java    |  78 ++++++++++
 ...ecycleManagerRemovingTransformEvaluator.java |  80 +++++++++++
 .../direct/ParDoMultiEvaluatorFactory.java      |  56 ++++----
 .../direct/ParDoSingleEvaluatorFactory.java     |  43 +++---
 ...readLocalInvalidatingTransformEvaluator.java |  63 --------
 .../runners/direct/CloningThreadLocalTest.java  |  92 ------------
 ...leManagerRemovingTransformEvaluatorTest.java | 144 +++++++++++++++++++
 .../direct/DoFnLifecycleManagerTest.java        | 119 +++++++++++++++
 ...LocalInvalidatingTransformEvaluatorTest.java | 135 -----------------
 10 files changed, 475 insertions(+), 378 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CloningThreadLocal.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CloningThreadLocal.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CloningThreadLocal.java
deleted file mode 100644
index b9dc4ca..0000000
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CloningThreadLocal.java
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.direct;
-
-import org.apache.beam.sdk.util.SerializableUtils;
-
-import java.io.Serializable;
-
-/**
- * A {@link ThreadLocal} that obtains the initial value by cloning an original value.
- */
-class CloningThreadLocal<T extends Serializable> extends ThreadLocal<T> {
-  public static <T extends Serializable> CloningThreadLocal<T> of(T original) {
-    return new CloningThreadLocal<>(original);
-  }
-
-  private final T original;
-
-  private CloningThreadLocal(T original) {
-    this.original = original;
-  }
-
-  @Override
-  public T initialValue() {
-    return SerializableUtils.clone(original);
-  }
-}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java
new file mode 100644
index 0000000..2783657
--- /dev/null
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.OldDoFn;
+import org.apache.beam.sdk.util.SerializableUtils;
+
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
+
+/**
+ * Manages per-thread copies of an {@link OldDoFn}, each created by deserializing the original.
+ *
+ * <p>{@link DoFnLifecycleManager} behaves like a {@link ThreadLocal} holding a {@link DoFn}: the
+ * first {@link #get()} on a thread deserializes a fresh copy for that thread, {@link #remove()}
+ * discards the calling thread's copy, and {@link #removeAll()} discards every cached copy.
+ * NOTE(review): no DoFn setup/teardown methods are invoked by this version of the class.
+ */
+class DoFnLifecycleManager {
+  public static DoFnLifecycleManager of(OldDoFn<?, ?> original) {
+    return new DoFnLifecycleManager(original);
+  }
+  // Lazily populated: at most one copy of the original fn per thread that has called get().
+  private final LoadingCache<Thread, OldDoFn<?, ?>> outstanding;
+  // The loader serializes the original once; each cache miss deserializes an independent copy.
+  private DoFnLifecycleManager(OldDoFn<?, ?> original) {
+    this.outstanding = CacheBuilder.newBuilder().build(new DeserializingCacheLoader(original));
+  }
+  /** Returns this thread's copy, deserializing it on first use (see Guava LoadingCache#get). */
+  public OldDoFn<?, ?> get() throws Exception {
+    Thread currentThread = Thread.currentThread();
+    return outstanding.get(currentThread);
+  }
+  /** Discards the calling thread's copy; that thread's next get() deserializes a new one. */
+  public void remove() throws Exception {
+    Thread currentThread = Thread.currentThread();
+    outstanding.invalidate(currentThread);
+  }
+
+  /**
+   * Discards the cached copy of every thread; later {@link #get()} calls create fresh copies.
+   */
+  public void removeAll() throws Exception {
+    outstanding.invalidateAll();
+  }
+  // NOTE(review): uses no enclosing-instance state, so this could be a static nested class.
+  private class DeserializingCacheLoader extends CacheLoader<Thread, OldDoFn<?, ?>> {
+    private final byte[] original; // serialized form of the original fn, captured once
+
+    public DeserializingCacheLoader(OldDoFn<?, ?> original) {
+      this.original = SerializableUtils.serializeToByteArray(original);
+    }
+    // Runs on a cache miss; the string arg is a description (presumably for error messages).
+    @Override
+    public OldDoFn<?, ?> load(Thread key) throws Exception {
+      return (OldDoFn<?, ?>) SerializableUtils.deserializeFromByteArray(original,
+          "DoFn Copy in thread " + key.getName());
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java
new file mode 100644
index 0000000..f3d1d4f
--- /dev/null
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import org.apache.beam.sdk.util.WindowedValue;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A {@link TransformEvaluator} which delegates calls to an underlying {@link TransformEvaluator},
+ * discarding the calling thread's copy in a {@link DoFnLifecycleManager} if any call throws.
+ */
+class DoFnLifecycleManagerRemovingTransformEvaluator<InputT> implements TransformEvaluator<InputT> {
+  private static final Logger LOG =
+      LoggerFactory.getLogger(DoFnLifecycleManagerRemovingTransformEvaluator.class);
+  private final TransformEvaluator<InputT> underlying;
+  private final DoFnLifecycleManager lifecycleManager;
+  /** Wraps {@code underlying}; {@code threadLocal} is the manager to clean up on failure. */
+  public static <InputT> TransformEvaluator<InputT> wrapping(
+      TransformEvaluator<InputT> underlying, DoFnLifecycleManager threadLocal) {
+    return new DoFnLifecycleManagerRemovingTransformEvaluator<>(underlying, threadLocal);
+  }
+  // NOTE(review): "threadLocal" is a leftover name from the ThreadLocal-based predecessor.
+  private DoFnLifecycleManagerRemovingTransformEvaluator(
+      TransformEvaluator<InputT> underlying, DoFnLifecycleManager threadLocal) {
+    this.underlying = underlying;
+    this.lifecycleManager = threadLocal;
+  }
+  // On failure, drop this thread's DoFn copy so a later get() deserializes a fresh one.
+  @Override
+  public void processElement(WindowedValue<InputT> element) throws Exception {
+    try {
+      underlying.processElement(element);
+    } catch (Exception e) {
+      try {
+        lifecycleManager.remove();
+      } catch (Exception removalException) {
+        LOG.error(
+            "Exception encountered while cleaning up after processing an element",
+            removalException);
+        e.addSuppressed(removalException); // the original failure stays the primary exception
+      }
+      throw e;
+    }
+  }
+  // Same cleanup-and-rethrow handling as processElement; a successful result is passed through.
+  @Override
+  public TransformResult finishBundle() throws Exception {
+    try {
+      return underlying.finishBundle();
+    } catch (Exception e) {
+      try {
+        lifecycleManager.remove();
+      } catch (Exception removalException) {
+        LOG.error(
+            "Exception encountered while cleaning up after finishing a bundle",
+            removalException);
+        e.addSuppressed(removalException); // the original failure stays the primary exception
+      }
+      throw e;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
index 40533c0..f2455e1 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
@@ -31,6 +31,9 @@ import com.google.common.cache.CacheBuilder;
 import com.google.common.cache.CacheLoader;
 import com.google.common.cache.LoadingCache;
 
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
 import java.util.Map;
 
 /**
@@ -38,32 +41,26 @@ import java.util.Map;
  * {@link BoundMulti} primitive {@link PTransform}.
  */
 class ParDoMultiEvaluatorFactory implements TransformEvaluatorFactory {
-  private final LoadingCache<AppliedPTransform<?, ?, BoundMulti<?, ?>>, ThreadLocal<OldDoFn<?, ?>>>
+  private static final Logger LOG = LoggerFactory.getLogger(ParDoMultiEvaluatorFactory.class);
+  private final LoadingCache<AppliedPTransform<?, ?, BoundMulti<?, ?>>, DoFnLifecycleManager>
       fnClones;
 
   public ParDoMultiEvaluatorFactory() {
-    fnClones =
-        CacheBuilder.newBuilder()
-            .build(
-                new CacheLoader<
-                    AppliedPTransform<?, ?, BoundMulti<?, ?>>, ThreadLocal<OldDoFn<?, ?>>>() {
-                  @Override
-                  public ThreadLocal<OldDoFn<?, ?>> load(
-                      AppliedPTransform<?, ?, BoundMulti<?, ?>> key)
-                      throws Exception {
-                    @SuppressWarnings({"unchecked", "rawtypes"})
-                    ThreadLocal threadLocal =
-                        (ThreadLocal) CloningThreadLocal.of(key.getTransform().getFn());
-                    return threadLocal;
-                  }
-                });
+    fnClones = CacheBuilder.newBuilder()
+        .build(new CacheLoader<AppliedPTransform<?, ?, BoundMulti<?, ?>>, DoFnLifecycleManager>() {
+          @Override
+          public DoFnLifecycleManager load(AppliedPTransform<?, ?, BoundMulti<?, ?>> key)
+              throws Exception {
+            return DoFnLifecycleManager.of(key.getTransform().getFn());
+          }
+        });
   }
 
   @Override
   public <T> TransformEvaluator<T> forApplication(
       AppliedPTransform<?, ?, ?> application,
       CommittedBundle<?> inputBundle,
-      EvaluationContext evaluationContext) {
+      EvaluationContext evaluationContext) throws Exception {
     @SuppressWarnings({"unchecked", "rawtypes"})
     TransformEvaluator<T> evaluator =
         createMultiEvaluator((AppliedPTransform) application, inputBundle, evaluationContext);
@@ -71,38 +68,45 @@ class ParDoMultiEvaluatorFactory implements TransformEvaluatorFactory {
   }
 
   @Override
-  public void cleanup() {
-
+  public void cleanup() throws Exception {
+    for (DoFnLifecycleManager lifecycleManager : fnClones.asMap().values()) {
+      lifecycleManager.removeAll();
+    }
   }
 
   private <InT, OuT> TransformEvaluator<InT> createMultiEvaluator(
       AppliedPTransform<PCollection<InT>, PCollectionTuple, BoundMulti<InT, OuT>> application,
       CommittedBundle<InT> inputBundle,
-      EvaluationContext evaluationContext) {
+      EvaluationContext evaluationContext) throws Exception {
     Map<TupleTag<?>, PCollection<?>> outputs = application.getOutput().getAll();
 
-    @SuppressWarnings({"unchecked", "rawtypes"})
-    ThreadLocal<OldDoFn<InT, OuT>> fnLocal =
-        (ThreadLocal) fnClones.getUnchecked((AppliedPTransform) application);
+    DoFnLifecycleManager fnLocal = fnClones.getUnchecked((AppliedPTransform) application);
     String stepName = evaluationContext.getStepName(application);
     DirectStepContext stepContext =
         evaluationContext.getExecutionContext(application, inputBundle.getKey())
             .getOrCreateStepContext(stepName, stepName);
     try {
+      @SuppressWarnings({"unchecked", "rawtypes"})
       TransformEvaluator<InT> parDoEvaluator =
           ParDoEvaluator.create(
               evaluationContext,
               stepContext,
               inputBundle,
               application,
-              fnLocal.get(),
+              (OldDoFn) fnLocal.get(),
               application.getTransform().getSideInputs(),
               application.getTransform().getMainOutputTag(),
               application.getTransform().getSideOutputTags().getAll(),
               outputs);
-      return ThreadLocalInvalidatingTransformEvaluator.wrapping(parDoEvaluator, fnLocal);
+      return DoFnLifecycleManagerRemovingTransformEvaluator.wrapping(parDoEvaluator, fnLocal);
     } catch (Exception e) {
-      fnLocal.remove();
+      try {
+        fnLocal.remove();
+      } catch (Exception removalException) {
+        LOG.error("Exception encountered while cleaning up in ParDo evaluator construction",
+            removalException);
+        e.addSuppressed(removalException);
+      }
       throw e;
     }
   }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
index 201fb46..a0fbd1d 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
@@ -31,6 +31,9 @@ import com.google.common.cache.CacheLoader;
 import com.google.common.cache.LoadingCache;
 import com.google.common.collect.ImmutableMap;
 
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
 import java.util.Collections;
 
 /**
@@ -38,22 +41,18 @@ import java.util.Collections;
  * {@link Bound ParDo.Bound} primitive {@link PTransform}.
  */
 class ParDoSingleEvaluatorFactory implements TransformEvaluatorFactory {
-  private final LoadingCache<AppliedPTransform<?, ?, Bound<?, ?>>, ThreadLocal<OldDoFn<?, ?>>>
-      fnClones;
+  private static final Logger LOG = LoggerFactory.getLogger(ParDoSingleEvaluatorFactory.class);
+  private final LoadingCache<AppliedPTransform<?, ?, Bound<?, ?>>, DoFnLifecycleManager> fnClones;
 
   public ParDoSingleEvaluatorFactory() {
     fnClones =
         CacheBuilder.newBuilder()
             .build(
-                new CacheLoader<
-                    AppliedPTransform<?, ?, Bound<?, ?>>, ThreadLocal<OldDoFn<?, ?>>>() {
+                new CacheLoader<AppliedPTransform<?, ?, Bound<?, ?>>, DoFnLifecycleManager>() {
                   @Override
-                  public ThreadLocal<OldDoFn<?, ?>> load(AppliedPTransform<?, ?, Bound<?, ?>> key)
+                  public DoFnLifecycleManager load(AppliedPTransform<?, ?, Bound<?, ?>> key)
                       throws Exception {
-                    @SuppressWarnings({"unchecked", "rawtypes"})
-                    ThreadLocal threadLocal =
-                        (ThreadLocal) CloningThreadLocal.of(key.getTransform().getFn());
-                    return threadLocal;
+                    return DoFnLifecycleManager.of(key.getTransform().getFn());
                   }
                 });
   }
@@ -62,7 +61,7 @@ class ParDoSingleEvaluatorFactory implements TransformEvaluatorFactory {
   public <T> TransformEvaluator<T> forApplication(
       final AppliedPTransform<?, ?, ?> application,
       CommittedBundle<?> inputBundle,
-      EvaluationContext evaluationContext) {
+      EvaluationContext evaluationContext) throws Exception {
     @SuppressWarnings({"unchecked", "rawtypes"})
     TransformEvaluator<T> evaluator =
         createSingleEvaluator((AppliedPTransform) application, inputBundle, evaluationContext);
@@ -70,39 +69,45 @@ class ParDoSingleEvaluatorFactory implements TransformEvaluatorFactory {
   }
 
   @Override
-  public void cleanup() {
-
+  public void cleanup() throws Exception {
+    for (DoFnLifecycleManager lifecycleManager : fnClones.asMap().values()) {
+      lifecycleManager.removeAll();
+    }
   }
 
   private <InputT, OutputT> TransformEvaluator<InputT> createSingleEvaluator(
       AppliedPTransform<PCollection<InputT>, PCollection<OutputT>, Bound<InputT, OutputT>>
           application,
       CommittedBundle<InputT> inputBundle,
-      EvaluationContext evaluationContext) {
+      EvaluationContext evaluationContext) throws Exception {
     TupleTag<OutputT> mainOutputTag = new TupleTag<>("out");
     String stepName = evaluationContext.getStepName(application);
     DirectStepContext stepContext =
         evaluationContext.getExecutionContext(application, inputBundle.getKey())
             .getOrCreateStepContext(stepName, stepName);
 
-    @SuppressWarnings({"unchecked", "rawtypes"})
-    ThreadLocal<OldDoFn<InputT, OutputT>> fnLocal =
-        (ThreadLocal) fnClones.getUnchecked((AppliedPTransform) application);
+    DoFnLifecycleManager fnLocal = fnClones.getUnchecked((AppliedPTransform) application);
     try {
+      @SuppressWarnings({"unchecked", "rawtypes"})
       ParDoEvaluator<InputT> parDoEvaluator =
           ParDoEvaluator.create(
               evaluationContext,
               stepContext,
               inputBundle,
               application,
-              fnLocal.get(),
+              (OldDoFn) fnLocal.get(),
               application.getTransform().getSideInputs(),
               mainOutputTag,
               Collections.<TupleTag<?>>emptyList(),
               ImmutableMap.<TupleTag<?>, PCollection<?>>of(mainOutputTag, application.getOutput()));
-      return ThreadLocalInvalidatingTransformEvaluator.wrapping(parDoEvaluator, fnLocal);
+      return DoFnLifecycleManagerRemovingTransformEvaluator.wrapping(parDoEvaluator, fnLocal);
     } catch (Exception e) {
-      fnLocal.remove();
+      try {
+        fnLocal.remove();
+      } catch (Exception removalException) {
+        LOG.error("Exception encountered constructing ParDo evaluator", removalException);
+        e.addSuppressed(removalException);
+      }
       throw e;
     }
   }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ThreadLocalInvalidatingTransformEvaluator.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ThreadLocalInvalidatingTransformEvaluator.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ThreadLocalInvalidatingTransformEvaluator.java
deleted file mode 100644
index d8a6bf9..0000000
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ThreadLocalInvalidatingTransformEvaluator.java
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.direct;
-
-import org.apache.beam.sdk.util.WindowedValue;
-
-/**
- * A {@link TransformEvaluator} which delegates calls to an underlying {@link TransformEvaluator},
- * clearing the value of a {@link ThreadLocal} if any call throws an exception.
- */
-class ThreadLocalInvalidatingTransformEvaluator<InputT>
-    implements TransformEvaluator<InputT> {
-  private final TransformEvaluator<InputT> underlying;
-  private final ThreadLocal<?> threadLocal;
-
-  public static <InputT> TransformEvaluator<InputT> wrapping(
-      TransformEvaluator<InputT> underlying,
-      ThreadLocal<?> threadLocal) {
-    return new ThreadLocalInvalidatingTransformEvaluator<>(underlying, threadLocal);
-  }
-
-  private ThreadLocalInvalidatingTransformEvaluator(
-      TransformEvaluator<InputT> underlying, ThreadLocal<?> threadLocal) {
-    this.underlying = underlying;
-    this.threadLocal = threadLocal;
-  }
-
-  @Override
-  public void processElement(WindowedValue<InputT> element) throws Exception {
-    try {
-      underlying.processElement(element);
-    } catch (Exception e) {
-      threadLocal.remove();
-      throw e;
-    }
-  }
-
-  @Override
-  public TransformResult finishBundle() throws Exception {
-    try {
-      return underlying.finishBundle();
-    } catch (Exception e) {
-      threadLocal.remove();
-      throw e;
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CloningThreadLocalTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CloningThreadLocalTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CloningThreadLocalTest.java
deleted file mode 100644
index 298db46..0000000
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CloningThreadLocalTest.java
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.direct;
-
-import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.nullValue;
-import static org.hamcrest.core.IsNot.not;
-import static org.hamcrest.core.IsSame.theInstance;
-import static org.junit.Assert.assertThat;
-
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-
-import java.io.Serializable;
-import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
-
-/**
- * Tests for {@link CloningThreadLocalTest}.
- */
-@RunWith(JUnit4.class)
-public class CloningThreadLocalTest {
-  @Test
-  public void returnsCopiesOfOriginal() throws Exception {
-    Record original = new Record();
-    ThreadLocal<Record> loaded = CloningThreadLocal.of(original);
-    assertThat(loaded.get(), not(nullValue()));
-    assertThat(loaded.get(), equalTo(original));
-    assertThat(loaded.get(), not(theInstance(original)));
-  }
-
-  @Test
-  public void returnsDifferentCopiesInDifferentThreads() throws Exception {
-    Record original = new Record();
-    final ThreadLocal<Record> loaded = CloningThreadLocal.of(original);
-    assertThat(loaded.get(), not(nullValue()));
-    assertThat(loaded.get(), equalTo(original));
-    assertThat(loaded.get(), not(theInstance(original)));
-
-    Callable<Record> otherThread =
-        new Callable<Record>() {
-          @Override
-          public Record call() throws Exception {
-            return loaded.get();
-          }
-        };
-    Record sameThread = loaded.get();
-    Record firstOtherThread = Executors.newSingleThreadExecutor().submit(otherThread).get();
-    Record secondOtherThread = Executors.newSingleThreadExecutor().submit(otherThread).get();
-
-    assertThat(sameThread, equalTo(firstOtherThread));
-    assertThat(sameThread, equalTo(secondOtherThread));
-    assertThat(sameThread, not(theInstance(firstOtherThread)));
-    assertThat(sameThread, not(theInstance(secondOtherThread)));
-    assertThat(firstOtherThread, not(theInstance(secondOtherThread)));
-  }
-
-  private static class Record implements Serializable {
-    private final double rand = Math.random();
-
-    @Override
-    public boolean equals(Object other) {
-      if (!(other instanceof Record)) {
-        return false;
-      }
-      Record that = (Record) other;
-      return this.rand == that.rand;
-    }
-
-    @Override
-    public int hashCode() {
-      return 1;
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluatorTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluatorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluatorTest.java
new file mode 100644
index 0000000..67f4ff4
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluatorTest.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
+import static org.hamcrest.core.Is.is;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.fail;
+
+import org.apache.beam.sdk.transforms.OldDoFn;
+import org.apache.beam.sdk.util.WindowedValue;
+
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Tests for {@link DoFnLifecycleManagerRemovingTransformEvaluator}.
+ */
+@RunWith(JUnit4.class)
+public class DoFnLifecycleManagerRemovingTransformEvaluatorTest {
+  private DoFnLifecycleManager lifecycleManager;
+
+  @Before
+  public void setup() {
+    lifecycleManager = DoFnLifecycleManager.of(new TestFn());
+  }
+
+  /**
+   * Elements and finishBundle are forwarded to the wrapped evaluator, and the cached
+   * {@link OldDoFn} instance is kept when nothing throws.
+   */
+  @Test
+  public void delegatesToUnderlying() throws Exception {
+    RecordingTransformEvaluator underlying = new RecordingTransformEvaluator();
+    OldDoFn<?, ?> original = lifecycleManager.get();
+    TransformEvaluator<Object> evaluator =
+        DoFnLifecycleManagerRemovingTransformEvaluator.wrapping(underlying, lifecycleManager);
+    WindowedValue<Object> first = WindowedValue.valueInGlobalWindow(new Object());
+    WindowedValue<Object> second = WindowedValue.valueInGlobalWindow(new Object());
+    evaluator.processElement(first);
+    assertThat(underlying.objects, containsInAnyOrder(first));
+    evaluator.processElement(second);
+    assertThat(underlying.finishBundleCalled, is(false));
+    evaluator.finishBundle();
+
+    assertThat(underlying.finishBundleCalled, is(true));
+    assertThat(underlying.objects, containsInAnyOrder(second, first));
+    assertThat(lifecycleManager.get(), Matchers.<OldDoFn<?, ?>>theInstance(original));
+  }
+
+  /** An exception thrown from processElement discards the cached {@link OldDoFn}. */
+  @Test
+  public void removesOnExceptionInProcessElement() throws Exception {
+    ThrowingTransformEvaluator underlying = new ThrowingTransformEvaluator();
+    OldDoFn<?, ?> original = lifecycleManager.get();
+    assertThat(original, not(nullValue()));
+    TransformEvaluator<Object> evaluator =
+        DoFnLifecycleManagerRemovingTransformEvaluator.wrapping(underlying, lifecycleManager);
+
+    try {
+      evaluator.processElement(WindowedValue.valueInGlobalWindow(new Object()));
+    } catch (Exception e) {
+      assertThat(lifecycleManager.get(), not(Matchers.<OldDoFn<?, ?>>theInstance(original)));
+      return;
+    }
+    fail("Expected ThrowingTransformEvaluator to throw on method call");
+  }
+
+  /** An exception thrown from finishBundle discards the cached {@link OldDoFn}. */
+  @Test
+  public void removesOnExceptionInFinishBundle() throws Exception {
+    ThrowingTransformEvaluator underlying = new ThrowingTransformEvaluator();
+    // Obtain the instance up front so the test can check that it is replaced after the failure.
+    OldDoFn<?, ?> original = lifecycleManager.get();
+    assertThat(original, not(nullValue()));
+    TransformEvaluator<Object> evaluator =
+        DoFnLifecycleManagerRemovingTransformEvaluator.wrapping(underlying, lifecycleManager);
+
+    try {
+      evaluator.finishBundle();
+    } catch (Exception e) {
+      assertThat(lifecycleManager.get(), not(Matchers.<OldDoFn<?, ?>>theInstance(original)));
+      return;
+    }
+    fail("Expected ThrowingTransformEvaluator to throw on method call");
+  }
+
+  /** Records each element it is given and whether finishBundle has been called. */
+  private static class RecordingTransformEvaluator implements TransformEvaluator<Object> {
+    private boolean finishBundleCalled;
+    private final List<WindowedValue<Object>> objects;
+
+    public RecordingTransformEvaluator() {
+      // Must start false, otherwise the finishBundleCalled assertion can never fail.
+      this.finishBundleCalled = false;
+      this.objects = new ArrayList<>();
+    }
+
+    @Override
+    public void processElement(WindowedValue<Object> element) throws Exception {
+      objects.add(element);
+    }
+
+    @Override
+    public TransformResult finishBundle() throws Exception {
+      finishBundleCalled = true;
+      return null;
+    }
+  }
+
+  /** Throws from every method, to exercise the wrapper's failure path. */
+  private static class ThrowingTransformEvaluator implements TransformEvaluator<Object> {
+    @Override
+    public void processElement(WindowedValue<Object> element) throws Exception {
+      throw new Exception();
+    }
+
+    @Override
+    public TransformResult finishBundle() throws Exception {
+      throw new Exception();
+    }
+  }
+
+  /** A no-op DoFn handed to {@link DoFnLifecycleManager#of}. */
+  private static class TestFn extends OldDoFn<Object, Object> {
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java
new file mode 100644
index 0000000..f316e19
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.theInstance;
+import static org.junit.Assert.assertThat;
+
+import org.apache.beam.sdk.transforms.OldDoFn;
+
+import org.hamcrest.Matchers;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.IdentityHashMap;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Tests for {@link DoFnLifecycleManager}.
+ */
+@RunWith(JUnit4.class)
+public class DoFnLifecycleManagerTest {
+  private static final int NUM_THREADS = 10;
+
+  private final TestFn fn = new TestFn();
+  private final DoFnLifecycleManager mgr = DoFnLifecycleManager.of(fn);
+
+  /** get() returns a new instance rather than the DoFn passed to of(). */
+  @Test
+  public void setupOnGet() throws Exception {
+    TestFn obtained = (TestFn) mgr.get();
+
+    assertThat(obtained, not(theInstance(fn)));
+  }
+
+  /** Repeated get() calls on one thread return the same instance. */
+  @Test
+  public void getMultipleCallsSingleSetupCall() throws Exception {
+    TestFn obtained = (TestFn) mgr.get();
+    TestFn secondObtained = (TestFn) mgr.get();
+
+    assertThat(obtained, theInstance(secondObtained));
+  }
+
+  /** get() calls made concurrently from different threads each return a distinct instance. */
+  @Test
+  public void getMultipleThreadsDifferentInstances() throws Exception {
+    CountDownLatch startSignal = new CountDownLatch(1);
+    ExecutorService executor = Executors.newCachedThreadPool();
+    try {
+      List<Future<TestFn>> futures = new ArrayList<>();
+      for (int i = 0; i < NUM_THREADS; i++) {
+        futures.add(executor.submit(new GetFnCallable(mgr, startSignal)));
+      }
+      // Every task blocks on the latch, so the cached pool runs each on its own thread.
+      startSignal.countDown();
+      // Collect by identity: the assertion is about distinct instances, not equal ones.
+      Set<TestFn> distinctFns = Collections.newSetFromMap(new IdentityHashMap<TestFn, Boolean>());
+      for (Future<TestFn> future : futures) {
+        distinctFns.add(future.get(1L, TimeUnit.SECONDS));
+      }
+
+      assertThat(distinctFns.size(), equalTo(NUM_THREADS));
+    } finally {
+      // Release the pool's threads so they do not outlive the test.
+      executor.shutdownNow();
+    }
+  }
+
+  /** After remove(), the next get() on the same thread returns a different instance. */
+  @Test
+  public void teardownOnRemove() throws Exception {
+    TestFn obtained = (TestFn) mgr.get();
+    mgr.remove();
+
+    assertThat(obtained, not(theInstance(fn)));
+
+    assertThat(mgr.get(), not(Matchers.<OldDoFn<?, ?>>theInstance(obtained)));
+  }
+
+  /** Waits for the start signal, then calls get() on the manager from the executing thread. */
+  private static class GetFnCallable implements Callable<TestFn> {
+    private final DoFnLifecycleManager mgr;
+    private final CountDownLatch startSignal;
+
+    private GetFnCallable(DoFnLifecycleManager mgr, CountDownLatch startSignal) {
+      this.mgr = mgr;
+      this.startSignal = startSignal;
+    }
+
+    @Override
+    public TestFn call() throws Exception {
+      startSignal.await();
+      return (TestFn) mgr.get();
+    }
+  }
+
+  /** A no-op DoFn handed to {@link DoFnLifecycleManager#of}. */
+  private static class TestFn extends OldDoFn<Object, Object> {
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ThreadLocalInvalidatingTransformEvaluatorTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ThreadLocalInvalidatingTransformEvaluatorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ThreadLocalInvalidatingTransformEvaluatorTest.java
deleted file mode 100644
index 6e477d3..0000000
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ThreadLocalInvalidatingTransformEvaluatorTest.java
+++ /dev/null
@@ -1,135 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.direct;
-
-import static org.hamcrest.Matchers.containsInAnyOrder;
-import static org.hamcrest.Matchers.not;
-import static org.hamcrest.Matchers.nullValue;
-import static org.hamcrest.core.Is.is;
-import static org.junit.Assert.assertThat;
-import static org.junit.Assert.fail;
-
-import org.apache.beam.sdk.util.WindowedValue;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * Tests for {@link ThreadLocalInvalidatingTransformEvaluator}.
- */
-@RunWith(JUnit4.class)
-public class ThreadLocalInvalidatingTransformEvaluatorTest {
-  private ThreadLocal<Object> threadLocal;
-
-  @Before
-  public void setup() {
-    threadLocal = new ThreadLocal<>();
-    threadLocal.set(new Object());
-  }
-
-  @Test
-  public void delegatesToUnderlying() throws Exception {
-    RecordingTransformEvaluator underlying = new RecordingTransformEvaluator();
-    Object original = threadLocal.get();
-    TransformEvaluator<Object> evaluator =
-        ThreadLocalInvalidatingTransformEvaluator.wrapping(underlying, threadLocal);
-    WindowedValue<Object> first = WindowedValue.valueInGlobalWindow(new Object());
-    WindowedValue<Object> second = WindowedValue.valueInGlobalWindow(new Object());
-    evaluator.processElement(first);
-    assertThat(underlying.objects, containsInAnyOrder(first));
-    evaluator.processElement(second);
-    evaluator.finishBundle();
-
-    assertThat(underlying.finishBundleCalled, is(true));
-    assertThat(underlying.objects, containsInAnyOrder(second, first));
-  }
-
-  @Test
-  public void removesOnExceptionInProcessElement() {
-    ThrowingTransformEvaluator underlying = new ThrowingTransformEvaluator();
-    Object original = threadLocal.get();
-    assertThat(original, not(nullValue()));
-    TransformEvaluator<Object> evaluator =
-        ThreadLocalInvalidatingTransformEvaluator.wrapping(underlying, threadLocal);
-
-    try {
-      evaluator.processElement(WindowedValue.valueInGlobalWindow(new Object()));
-    } catch (Exception e) {
-      assertThat(threadLocal.get(), nullValue());
-      return;
-    }
-    fail("Expected ThrowingTransformEvaluator to throw on method call");
-  }
-
-  @Test
-  public void removesOnExceptionInFinishBundle() {
-    ThrowingTransformEvaluator underlying = new ThrowingTransformEvaluator();
-    Object original = threadLocal.get();
-    // the ThreadLocal is set when the evaluator starts
-    assertThat(original, not(nullValue()));
-    TransformEvaluator<Object> evaluator =
-        ThreadLocalInvalidatingTransformEvaluator.wrapping(underlying, threadLocal);
-
-    try {
-      evaluator.finishBundle();
-    } catch (Exception e) {
-      assertThat(threadLocal.get(), nullValue());
-      return;
-    }
-    fail("Expected ThrowingTransformEvaluator to throw on method call");
-  }
-
-  private class RecordingTransformEvaluator implements TransformEvaluator<Object> {
-    private boolean finishBundleCalled;
-    private List<WindowedValue<Object>> objects;
-
-    public RecordingTransformEvaluator() {
-      this.finishBundleCalled = true;
-      this.objects = new ArrayList<>();
-    }
-
-    @Override
-    public void processElement(WindowedValue<Object> element) throws Exception {
-      objects.add(element);
-    }
-
-    @Override
-    public TransformResult finishBundle() throws Exception {
-      finishBundleCalled = true;
-      return null;
-    }
-  }
-
-  private class ThrowingTransformEvaluator implements TransformEvaluator<Object> {
-    @Override
-    public void processElement(WindowedValue<Object> element) throws Exception {
-      throw new Exception();
-    }
-
-    @Override
-    public TransformResult finishBundle() throws Exception {
-      throw new Exception();
-    }
-  }
-}



[3/5] incubator-beam git commit: Add TransformEvaluatorFactory#cleanup

Posted by dh...@apache.org.
Add TransformEvaluatorFactory#cleanup

This lets each TransformEvaluatorFactory clean up any state it stores once the pipeline is shut down.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/12b19677
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/12b19677
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/12b19677

Branch: refs/heads/master
Commit: 12b19677280c11b0dca203ef266769b05c90937e
Parents: 0b1f664
Author: Thomas Groh <tg...@google.com>
Authored: Fri Jul 15 11:27:00 2016 -0700
Committer: Dan Halperin <dh...@google.com>
Committed: Mon Aug 15 14:16:54 2016 -0700

----------------------------------------------------------------------
 .../direct/BoundedReadEvaluatorFactory.java     |  4 ++
 .../direct/ExecutorServiceParallelExecutor.java |  9 ++++-
 .../runners/direct/FlattenEvaluatorFactory.java |  3 ++
 .../GroupAlsoByWindowEvaluatorFactory.java      |  6 ++-
 .../direct/GroupByKeyOnlyEvaluatorFactory.java  |  4 +-
 .../direct/ParDoMultiEvaluatorFactory.java      |  5 +++
 .../direct/ParDoSingleEvaluatorFactory.java     |  5 +++
 .../direct/TransformEvaluatorFactory.java       |  8 ++++
 .../direct/TransformEvaluatorRegistry.java      | 41 ++++++++++++++++++++
 .../direct/UnboundedReadEvaluatorFactory.java   |  3 ++
 .../runners/direct/ViewEvaluatorFactory.java    |  3 ++
 .../runners/direct/WindowEvaluatorFactory.java  |  3 ++
 12 files changed, 90 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12b19677/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java
index 2f4f86c..0c4b7fd 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java
@@ -60,6 +60,12 @@ final class BoundedReadEvaluatorFactory implements TransformEvaluatorFactory {
     return getTransformEvaluator((AppliedPTransform) application, evaluationContext);
   }
 
+  @Override
+  public void cleanup() {
+    // NOTE(review): intentionally empty; the evaluator queues described below are not
+    // cleared here. Confirm they hold nothing that needs releasing at shutdown.
+  }
+
   /**
    * Get a {@link TransformEvaluator} that produces elements for the provided application of
    * {@link Bounded Read.Bounded}, initializing the queue of evaluators if required.

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12b19677/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java
index a0a5ec0..8c6c6ed 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java
@@ -447,13 +447,24 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor {
     private boolean shouldShutdown() {
       boolean shouldShutdown = exceptionThrown || evaluationContext.isDone();
       if (shouldShutdown) {
+        LOG.debug("Pipeline has terminated. Shutting down.");
+        executorService.shutdown();
+        try {
+          registry.cleanup();
+        } catch (Exception e) {
+          // visibleUpdates is drained with offer()/poll() below, which suggests it is
+          // capacity-bounded; add() would then throw IllegalStateException when it is full and
+          // the cleanup failure would be lost. Make room the same way instead.
+          VisibleExecutorUpdate cleanupFailure = VisibleExecutorUpdate.fromThrowable(e);
+          while (!visibleUpdates.offer(cleanupFailure)) {
+            visibleUpdates.poll();
+          }
+        }
         if (evaluationContext.isDone()) {
-          LOG.debug("Pipeline is finished. Shutting down. {}");
           while (!visibleUpdates.offer(VisibleExecutorUpdate.finished())) {
             visibleUpdates.poll();
           }
         }
-        executorService.shutdown();
       }
       return shouldShutdown;
     }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12b19677/runners/direct-java/src/main/java/org/apache/beam/runners/direct/FlattenEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/FlattenEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/FlattenEvaluatorFactory.java
index c84f620..5a0d31d 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/FlattenEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/FlattenEvaluatorFactory.java
@@ -43,6 +43,10 @@ class FlattenEvaluatorFactory implements TransformEvaluatorFactory {
     return evaluator;
   }
 
+  /** No-op: this factory does not release anything on cleanup. */
+  @Override
+  public void cleanup() {}
+
   private <InputT> TransformEvaluator<InputT> createInMemoryEvaluator(
       final AppliedPTransform<
               PCollectionList<InputT>, PCollection<InputT>, FlattenPCollectionList<InputT>>

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12b19677/runners/direct-java/src/main/java/org/apache/beam/runners/direct/GroupAlsoByWindowEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/GroupAlsoByWindowEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/GroupAlsoByWindowEvaluatorFactory.java
index e052226..d16ffa0 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/GroupAlsoByWindowEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/GroupAlsoByWindowEvaluatorFactory.java
@@ -61,11 +61,15 @@ class GroupAlsoByWindowEvaluatorFactory implements TransformEvaluatorFactory {
     return evaluator;
   }
 
+  /** No-op: this factory does not release anything on cleanup. */
+  @Override
+  public void cleanup() {}
+
   private <K, V> TransformEvaluator<KeyedWorkItem<K, V>> createEvaluator(
       AppliedPTransform<
               PCollection<KeyedWorkItem<K, V>>,
               PCollection<KV<K, Iterable<V>>>,
               DirectGroupAlsoByWindow<K, V>> application,
       CommittedBundle<KeyedWorkItem<K, V>> inputBundle,
       EvaluationContext evaluationContext) {
     return new GroupAlsoByWindowEvaluator<>(

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12b19677/runners/direct-java/src/main/java/org/apache/beam/runners/direct/GroupByKeyOnlyEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/GroupByKeyOnlyEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/GroupByKeyOnlyEvaluatorFactory.java
index 0e419c3..dbdbdaf 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/GroupByKeyOnlyEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/GroupByKeyOnlyEvaluatorFactory.java
@@ -18,7 +18,6 @@
 package org.apache.beam.runners.direct;
 
 import static org.apache.beam.sdk.util.CoderUtils.encodeToByteArray;
-
 import static com.google.common.base.Preconditions.checkState;
 
 import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupByKeyOnly;
@@ -61,6 +60,10 @@ class GroupByKeyOnlyEvaluatorFactory implements TransformEvaluatorFactory {
     return evaluator;
   }
 
+  /** No-op: this factory does not release anything on cleanup. */
+  @Override
+  public void cleanup() {}
+
   private <K, V> TransformEvaluator<KV<K, WindowedValue<V>>> createEvaluator(
       final AppliedPTransform<
           PCollection<KV<K, WindowedValue<V>>>,

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12b19677/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
index ce770ca..40533c0 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
@@ -70,6 +70,10 @@ class ParDoMultiEvaluatorFactory implements TransformEvaluatorFactory {
     return evaluator;
   }
 
+  /** No-op: this factory does not release anything on cleanup. */
+  @Override
+  public void cleanup() {}
+
   private <InT, OuT> TransformEvaluator<InT> createMultiEvaluator(
       AppliedPTransform<PCollection<InT>, PCollectionTuple, BoundMulti<InT, OuT>> application,
       CommittedBundle<InT> inputBundle,

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12b19677/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
index 53af6af..201fb46 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
@@ -69,6 +69,10 @@ class ParDoSingleEvaluatorFactory implements TransformEvaluatorFactory {
     return evaluator;
   }
 
+  /** No-op: this factory does not release anything on cleanup. */
+  @Override
+  public void cleanup() {}
+
   private <InputT, OutputT> TransformEvaluator<InputT> createSingleEvaluator(
       AppliedPTransform<PCollection<InputT>, PCollection<OutputT>, Bound<InputT, OutputT>>
           application,

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12b19677/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorFactory.java
index d021b43..3655d26 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorFactory.java
@@ -51,4 +51,14 @@ public interface TransformEvaluatorFactory {
   @Nullable <InputT> TransformEvaluator<InputT> forApplication(
       AppliedPTransform<?, ?, ?> application, @Nullable CommittedBundle<?> inputBundle,
       EvaluationContext evaluationContext) throws Exception;
+
+  /**
+   * Cleans up any state maintained by this {@link TransformEvaluatorFactory}. Called after a
+   * {@link org.apache.beam.sdk.Pipeline} is shut down. No more calls to
+   * {@link #forApplication(AppliedPTransform, CommittedBundle, EvaluationContext)} will be made
+   * after a call to {@link #cleanup()}.
+   *
+   * @throws Exception if any state could not be cleaned up
+   */
+  void cleanup() throws Exception;
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12b19677/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
index f0afc3b..b469237 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
@@ -17,6 +17,8 @@
  */
 package org.apache.beam.runners.direct;
 
+import static com.google.common.base.Preconditions.checkState;
+
 import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupAlsoByWindow;
 import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupByKeyOnly;
 import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
@@ -29,7 +31,13 @@ import org.apache.beam.sdk.transforms.windowing.Window;
 
 import com.google.common.collect.ImmutableMap;
 
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Map;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 import javax.annotation.Nullable;
 
@@ -38,6 +46,7 @@ import javax.annotation.Nullable;
  * implementations based on the type of {@link PTransform} of the application.
  */
 class TransformEvaluatorRegistry implements TransformEvaluatorFactory {
+  private static final Logger LOG = LoggerFactory.getLogger(TransformEvaluatorRegistry.class);
   public static TransformEvaluatorRegistry defaultRegistry() {
     @SuppressWarnings("rawtypes")
     ImmutableMap<Class<? extends PTransform>, TransformEvaluatorFactory> primitives =
@@ -61,6 +70,8 @@ class TransformEvaluatorRegistry implements TransformEvaluatorFactory {
   @SuppressWarnings("rawtypes")
   private final Map<Class<? extends PTransform>, TransformEvaluatorFactory> factories;
 
+  private final AtomicBoolean finished = new AtomicBoolean(false);
+
   private TransformEvaluatorRegistry(
       @SuppressWarnings("rawtypes")
       Map<Class<? extends PTransform>, TransformEvaluatorFactory> factories) {
@@ -73,7 +84,46 @@ class TransformEvaluatorRegistry implements TransformEvaluatorFactory {
       @Nullable CommittedBundle<?> inputBundle,
       EvaluationContext evaluationContext)
       throws Exception {
+    checkState(
+        !finished.get(), "Tried to get an evaluator for a finished TransformEvaluatorRegistry");
     TransformEvaluatorFactory factory = factories.get(application.getTransform().getClass());
     return factory.forApplication(application, inputBundle, evaluationContext);
   }
+
+  /**
+   * Cleans up every registered factory, then rethrows the first failure (if any) with the
+   * remaining failures attached as suppressed exceptions.
+   *
+   * <p>The registry is marked finished before any factory is cleaned up, so calls to
+   * {@link #forApplication} made once cleanup has begun are rejected instead of obtaining an
+   * evaluator from a factory that is being torn down.
+   */
+  @Override
+  public void cleanup() throws Exception {
+    finished.set(true);
+    Collection<Exception> thrownInCleanup = new ArrayList<>();
+    for (TransformEvaluatorFactory factory : factories.values()) {
+      try {
+        factory.cleanup();
+      } catch (Exception e) {
+        if (e instanceof InterruptedException) {
+          // Restore the interrupt flag; the exception is still propagated below.
+          Thread.currentThread().interrupt();
+        }
+        thrownInCleanup.add(e);
+      }
+    }
+    if (!thrownInCleanup.isEmpty()) {
+      LOG.error("Exceptions {} thrown while cleaning up evaluators", thrownInCleanup);
+      Exception toThrow = null;
+      for (Exception e : thrownInCleanup) {
+        if (toThrow == null) {
+          toThrow = e;
+        } else {
+          toThrow.addSuppressed(e);
+        }
+      }
+      throw toThrow;
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12b19677/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java
index 0e2745b..c4d408b 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java
@@ -113,6 +113,12 @@ class UnboundedReadEvaluatorFactory implements TransformEvaluatorFactory {
     return evaluatorQueue.poll();
   }
 
+  @Override
+  public void cleanup() {
+    // NOTE(review): no-op. Evaluators held in evaluatorQueue are not released here; confirm
+    // whether any resources they hold need closing at shutdown.
+  }
+
   /**
    * A {@link UnboundedReadEvaluator} produces elements from an underlying {@link UnboundedSource},
    * discarding all input elements. Within the call to {@link #finishBundle()}, the evaluator

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12b19677/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java
index 362e903..3b0de4b 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java
@@ -59,6 +59,9 @@ class ViewEvaluatorFactory implements TransformEvaluatorFactory {
     return evaluator;
   }
 
+  @Override
+  public void cleanup() throws Exception {}
+
   private <InT, OuT> TransformEvaluator<Iterable<InT>> createEvaluator(
       final AppliedPTransform<PCollection<Iterable<InT>>, PCollectionView<OuT>, WriteView<InT, OuT>>
           application,

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12b19677/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WindowEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WindowEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WindowEvaluatorFactory.java
index 67c2f17..f2e62cb 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WindowEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WindowEvaluatorFactory.java
@@ -66,6 +66,9 @@ class WindowEvaluatorFactory implements TransformEvaluatorFactory {
     return new WindowIntoEvaluator<>(transform, fn, outputBundle);
   }
 
+  @Override
+  public void cleanup() {}
+
   private static class WindowIntoEvaluator<InputT> implements TransformEvaluator<InputT> {
     private final AppliedPTransform<PCollection<InputT>, PCollection<InputT>, Window.Bound<InputT>>
         transform;


[2/5] incubator-beam git commit: Add DoFn @Setup and @Teardown

Posted by dh...@apache.org.
Add DoFn @Setup and @Teardown

Methods annotated with @Setup perform expensive setup work for a DoFn;
methods annotated with @Teardown clean up a DoFn after another method
throws an exception or when the DoFn is discarded.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/12abb1b0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/12abb1b0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/12abb1b0

Branch: refs/heads/master
Commit: 12abb1b02246b8d36021c7b1a970daf1b64ba4b9
Parents: cf0bf3b
Author: Thomas Groh <tg...@google.com>
Authored: Thu Jul 14 14:51:02 2016 -0700
Committer: Dan Halperin <dh...@google.com>
Committed: Mon Aug 15 14:16:54 2016 -0700

----------------------------------------------------------------------
 .../runners/direct/DoFnLifecycleManager.java    |  38 +-
 ...ecycleManagerRemovingTransformEvaluator.java |  39 +-
 .../runners/direct/DoFnLifecycleManagers.java   |  45 ++
 .../direct/ParDoMultiEvaluatorFactory.java      |   4 +-
 .../direct/ParDoSingleEvaluatorFactory.java     |   4 +-
 .../direct/DoFnLifecycleManagerTest.java        |  49 +++
 .../direct/DoFnLifecycleManagersTest.java       | 142 +++++++
 .../functions/FlinkDoFnFunction.java            |  12 +-
 .../functions/FlinkMultiOutputDoFnFunction.java |  31 +-
 .../streaming/FlinkAbstractParDoWrapper.java    |   2 +
 .../FlinkGroupAlsoByWindowWrapper.java          |   2 +
 .../runners/spark/translation/DoFnFunction.java |  23 +-
 .../spark/translation/MultiDoFnFunction.java    |   1 +
 .../spark/translation/SparkProcessContext.java  |  17 +
 .../org/apache/beam/sdk/transforms/DoFn.java    |  31 +-
 .../beam/sdk/transforms/DoFnReflector.java      |  70 +++-
 .../org/apache/beam/sdk/transforms/OldDoFn.java |  25 ++
 .../org/apache/beam/sdk/transforms/ParDo.java   |  15 +-
 .../beam/sdk/transforms/DoFnReflectorTest.java  |  65 +++
 .../apache/beam/sdk/transforms/ParDoTest.java   | 420 ++++++++++++++++++-
 20 files changed, 970 insertions(+), 65 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java
index 2783657..3f4f2c6 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java
@@ -18,6 +18,7 @@
 
 package org.apache.beam.runners.direct;
 
+import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.OldDoFn;
 import org.apache.beam.sdk.util.SerializableUtils;
@@ -26,6 +27,13 @@ import com.google.common.cache.CacheBuilder;
 import com.google.common.cache.CacheLoader;
 import com.google.common.cache.LoadingCache;
 
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
+
 /**
  * Manages {@link DoFn} setup, teardown, and serialization.
  *
@@ -35,6 +43,8 @@ import com.google.common.cache.LoadingCache;
  * {@link DoFn DoFns}.
  */
 class DoFnLifecycleManager {
+  private static final Logger LOG = LoggerFactory.getLogger(DoFnLifecycleManager.class);
+
   public static DoFnLifecycleManager of(OldDoFn<?, ?> original) {
     return new DoFnLifecycleManager(original);
   }
@@ -52,14 +62,30 @@ class DoFnLifecycleManager {
 
   public void remove() throws Exception {
     Thread currentThread = Thread.currentThread();
-    outstanding.invalidate(currentThread);
+    OldDoFn<?, ?> fn = outstanding.asMap().remove(currentThread);
+    fn.teardown();
   }
 
   /**
-   * Remove all {@link DoFn DoFns} from this {@link DoFnLifecycleManager}.
+   * Remove all {@link DoFn DoFns} from this {@link DoFnLifecycleManager}. Returns all exceptions
+   * that were thrown while calling the teardown methods.
+   *
+   * <p>If the returned Collection is nonempty, an exception was thrown from at least one
+   * {@link OldDoFn#teardown()} method, and the {@link PipelineRunner} should throw an exception.
    */
-  public void removeAll() throws Exception {
-    outstanding.invalidateAll();
+  public Collection<Exception> removeAll() throws Exception {
+    Iterator<OldDoFn<?, ?>> fns = outstanding.asMap().values().iterator();
+    Collection<Exception> thrown = new ArrayList<>();
+    while (fns.hasNext()) {
+      OldDoFn<?, ?> fn = fns.next();
+      fns.remove();
+      try {
+        fn.teardown();
+      } catch (Exception e) {
+        thrown.add(e);
+      }
+    }
+    return thrown;
   }
 
   private class DeserializingCacheLoader extends CacheLoader<Thread, OldDoFn<?, ?>> {
@@ -71,8 +97,10 @@ class DoFnLifecycleManager {
 
     @Override
     public OldDoFn<?, ?> load(Thread key) throws Exception {
-      return (OldDoFn<?, ?>) SerializableUtils.deserializeFromByteArray(original,
+      OldDoFn<?, ?> fn = (OldDoFn<?, ?>) SerializableUtils.deserializeFromByteArray(original,
           "DoFn Copy in thread " + key.getName());
+      fn.setup();
+      return fn;
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java
index f3d1d4f..523273c 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java
@@ -34,14 +34,14 @@ class DoFnLifecycleManagerRemovingTransformEvaluator<InputT> implements Transfor
   private final DoFnLifecycleManager lifecycleManager;
 
   public static <InputT> TransformEvaluator<InputT> wrapping(
-      TransformEvaluator<InputT> underlying, DoFnLifecycleManager threadLocal) {
-    return new DoFnLifecycleManagerRemovingTransformEvaluator<>(underlying, threadLocal);
+      TransformEvaluator<InputT> underlying, DoFnLifecycleManager lifecycleManager) {
+    return new DoFnLifecycleManagerRemovingTransformEvaluator<>(underlying, lifecycleManager);
   }
 
   private DoFnLifecycleManagerRemovingTransformEvaluator(
-      TransformEvaluator<InputT> underlying, DoFnLifecycleManager threadLocal) {
+      TransformEvaluator<InputT> underlying, DoFnLifecycleManager lifecycleManager) {
     this.underlying = underlying;
-    this.lifecycleManager = threadLocal;
+    this.lifecycleManager = lifecycleManager;
   }
 
   @Override
@@ -49,14 +49,7 @@ class DoFnLifecycleManagerRemovingTransformEvaluator<InputT> implements Transfor
     try {
       underlying.processElement(element);
     } catch (Exception e) {
-      try {
-        lifecycleManager.remove();
-      } catch (Exception removalException) {
-        LOG.error(
-            "Exception encountered while cleaning up after processing an element",
-            removalException);
-        e.addSuppressed(removalException);
-      }
+      onException(e, "Exception encountered while cleaning up after processing an element");
       throw e;
     }
   }
@@ -66,15 +59,21 @@ class DoFnLifecycleManagerRemovingTransformEvaluator<InputT> implements Transfor
     try {
       return underlying.finishBundle();
     } catch (Exception e) {
-      try {
-        lifecycleManager.remove();
-      } catch (Exception removalException) {
-        LOG.error(
-            "Exception encountered while cleaning up after finishing a bundle",
-            removalException);
-        e.addSuppressed(removalException);
-      }
+      onException(e, "Exception encountered while cleaning up after finishing a bundle");
       throw e;
     }
   }
+
+  private void onException(Exception e, String msg) {
+    try {
+      lifecycleManager.remove();
+    } catch (Exception removalException) {
+      if (removalException instanceof InterruptedException) {
+        Thread.currentThread().interrupt();
+      }
+      LOG.error(msg, removalException);
+      e.addSuppressed(removalException);
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagers.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagers.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagers.java
new file mode 100644
index 0000000..6a1dd8f
--- /dev/null
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagers.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import java.util.ArrayList;
+import java.util.Collection;
+
+/**
+ * Utility methods for interacting with {@link DoFnLifecycleManager DoFnLifecycleManagers}.
+ */
+class DoFnLifecycleManagers {
+  private DoFnLifecycleManagers() {
+    /* Do not instantiate */
+  }
+
+  static void removeAllFromManagers(Iterable<DoFnLifecycleManager> managers) throws Exception {
+    Collection<Exception> thrown = new ArrayList<>();
+    for (DoFnLifecycleManager manager : managers) {
+      thrown.addAll(manager.removeAll());
+    }
+    if (!thrown.isEmpty()) {
+      Exception overallException = new Exception("Exceptions thrown while tearing down DoFns");
+      for (Exception e : thrown) {
+        overallException.addSuppressed(e);
+      }
+      throw overallException;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
index f2455e1..2d05e68 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
@@ -69,9 +69,7 @@ class ParDoMultiEvaluatorFactory implements TransformEvaluatorFactory {
 
   @Override
   public void cleanup() throws Exception {
-    for (DoFnLifecycleManager lifecycleManager : fnClones.asMap().values()) {
-      lifecycleManager.removeAll();
-    }
+    DoFnLifecycleManagers.removeAllFromManagers(fnClones.asMap().values());
   }
 
   private <InT, OuT> TransformEvaluator<InT> createMultiEvaluator(

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
index a0fbd1d..97cbfa7 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
@@ -70,9 +70,7 @@ class ParDoSingleEvaluatorFactory implements TransformEvaluatorFactory {
 
   @Override
   public void cleanup() throws Exception {
-    for (DoFnLifecycleManager lifecycleManager : fnClones.asMap().values()) {
-      lifecycleManager.removeAll();
-    }
+    DoFnLifecycleManagers.removeAllFromManagers(fnClones.asMap().values());
   }
 
   private <InputT, OutputT> TransformEvaluator<InputT> createSingleEvaluator(

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java
index f316e19..77b3296 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java
@@ -18,7 +18,9 @@
 
 package org.apache.beam.runners.direct;
 
+import static com.google.common.base.Preconditions.checkState;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.theInstance;
 import static org.junit.Assert.assertThat;
@@ -49,6 +51,8 @@ public class DoFnLifecycleManagerTest {
     TestFn obtained = (TestFn) mgr.get();
 
     assertThat(obtained, not(theInstance(fn)));
+    assertThat(obtained.setupCalled, is(true));
+    assertThat(obtained.teardownCalled, is(false));
   }
 
   @Test
@@ -57,6 +61,8 @@ public class DoFnLifecycleManagerTest {
     TestFn secondObtained = (TestFn) mgr.get();
 
     assertThat(obtained, theInstance(secondObtained));
+    assertThat(obtained.setupCalled, is(true));
+    assertThat(obtained.teardownCalled, is(false));
   }
 
   @Test
@@ -74,6 +80,7 @@ public class DoFnLifecycleManagerTest {
     }
 
     for (TestFn fn : fns) {
+      assertThat(fn.setupCalled, is(true));
       int sameInstances = 0;
       for (TestFn otherFn : fns) {
         if (otherFn == fn) {
@@ -90,10 +97,33 @@ public class DoFnLifecycleManagerTest {
     mgr.remove();
 
     assertThat(obtained, not(theInstance(fn)));
+    assertThat(obtained.setupCalled, is(true));
+    assertThat(obtained.teardownCalled, is(true));
 
     assertThat(mgr.get(), not(Matchers.<OldDoFn<?, ?>>theInstance(obtained)));
   }
 
+  @Test
+  public void teardownAllOnRemoveAll() throws Exception {
+    CountDownLatch startSignal = new CountDownLatch(1);
+    ExecutorService executor = Executors.newCachedThreadPool();
+    List<Future<TestFn>> futures = new ArrayList<>();
+    for (int i = 0; i < 10; i++) {
+      futures.add(executor.submit(new GetFnCallable(mgr, startSignal)));
+    }
+    startSignal.countDown();
+    List<TestFn> fns = new ArrayList<>();
+    for (Future<TestFn> future : futures) {
+      fns.add(future.get(1L, TimeUnit.SECONDS));
+    }
+    mgr.removeAll();
+
+    for (TestFn fn : fns) {
+      assertThat(fn.setupCalled, is(true));
+      assertThat(fn.teardownCalled, is(true));
+    }
+  }
+
   private static class GetFnCallable implements Callable<TestFn> {
     private final DoFnLifecycleManager mgr;
     private final CountDownLatch startSignal;
@@ -112,8 +142,27 @@ public class DoFnLifecycleManagerTest {
 
 
   private static class TestFn extends OldDoFn<Object, Object> {
+    boolean setupCalled = false;
+    boolean teardownCalled = false;
+
+    @Override
+    public void setup() {
+      checkState(!setupCalled);
+      checkState(!teardownCalled);
+
+      setupCalled = true;
+    }
+
     @Override
     public void processElement(ProcessContext c) throws Exception {
     }
+
+    @Override
+    public void teardown() {
+      checkState(setupCalled);
+      checkState(!teardownCalled);
+
+      teardownCalled = true;
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagersTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagersTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagersTest.java
new file mode 100644
index 0000000..8be3d52
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagersTest.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import static org.hamcrest.Matchers.equalTo;
+
+import org.apache.beam.sdk.transforms.OldDoFn;
+
+import com.google.common.collect.ImmutableList;
+
+import org.hamcrest.BaseMatcher;
+import org.hamcrest.Description;
+import org.hamcrest.Matcher;
+import org.hamcrest.Matchers;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.util.ArrayList;
+import java.util.Collection;
+
+/**
+ * Tests for {@link DoFnLifecycleManagers}.
+ */
+@RunWith(JUnit4.class)
+public class DoFnLifecycleManagersTest {
+  @Rule public ExpectedException thrown = ExpectedException.none();
+
+  @Test
+  public void removeAllWhenManagersThrowSuppressesAndThrows() throws Exception {
+    DoFnLifecycleManager first = DoFnLifecycleManager.of(new ThrowsInCleanupFn("foo"));
+    DoFnLifecycleManager second = DoFnLifecycleManager.of(new ThrowsInCleanupFn("bar"));
+    DoFnLifecycleManager third = DoFnLifecycleManager.of(new ThrowsInCleanupFn("baz"));
+    first.get();
+    second.get();
+    third.get();
+
+    final Collection<Matcher<? super Throwable>> suppressions = new ArrayList<>();
+    suppressions.add(new ThrowableMessageMatcher("foo"));
+    suppressions.add(new ThrowableMessageMatcher("bar"));
+    suppressions.add(new ThrowableMessageMatcher("baz"));
+
+    thrown.expect(
+        new BaseMatcher<Exception>() {
+          @Override
+          public void describeTo(Description description) {
+            description
+                .appendText("Exception suppressing ")
+                .appendList("[", ", ", "]", suppressions);
+          }
+
+          @Override
+          public boolean matches(Object item) {
+            if (!(item instanceof Exception)) {
+              return false;
+            }
+            Exception that = (Exception) item;
+            return Matchers.containsInAnyOrder(suppressions)
+                .matches(ImmutableList.copyOf(that.getSuppressed()));
+          }
+        });
+
+    DoFnLifecycleManagers.removeAllFromManagers(ImmutableList.of(first, second, third));
+  }
+
+  @Test
+  public void whenManagersSucceedSucceeds() throws Exception {
+    DoFnLifecycleManager first = DoFnLifecycleManager.of(new EmptyFn());
+    DoFnLifecycleManager second = DoFnLifecycleManager.of(new EmptyFn());
+    DoFnLifecycleManager third = DoFnLifecycleManager.of(new EmptyFn());
+    first.get();
+    second.get();
+    third.get();
+
+    DoFnLifecycleManagers.removeAllFromManagers(ImmutableList.of(first, second, third));
+  }
+
+  private static class ThrowsInCleanupFn extends OldDoFn<Object, Object> {
+    private final String message;
+
+    private ThrowsInCleanupFn(String message) {
+      this.message = message;
+    }
+
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+    }
+
+    @Override
+    public void teardown() throws Exception {
+      throw new Exception(message);
+    }
+  }
+
+
+  private static class ThrowableMessageMatcher extends BaseMatcher<Throwable> {
+    private final Matcher<String> messageMatcher;
+
+    public ThrowableMessageMatcher(String message) {
+      this.messageMatcher = equalTo(message);
+    }
+
+    @Override
+    public boolean matches(Object item) {
+      if (!(item instanceof Throwable)) {
+        return false;
+      }
+      Throwable that = (Throwable) item;
+      return messageMatcher.matches(that.getMessage());
+    }
+
+    @Override
+    public void describeTo(Description description) {
+      description.appendText("a throwable with a message ").appendDescriptionOf(messageMatcher);
+    }
+  }
+
+
+  private static class EmptyFn extends OldDoFn<Object, Object> {
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
index a4af1b0..fdf1e59 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
@@ -25,6 +25,7 @@ import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.values.PCollectionView;
 
 import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.util.Collector;
 
 import java.util.Map;
@@ -86,7 +87,7 @@ public class FlinkDoFnFunction<InputT, OutputT>
       // side inputs and window access also only works if an element
       // is in only one window
       for (WindowedValue<InputT> value : values) {
-        for (WindowedValue<InputT> explodedValue: value.explodeWindows()) {
+        for (WindowedValue<InputT> explodedValue : value.explodeWindows()) {
           context = context.forWindowedValue(value);
           doFn.processElement(context);
         }
@@ -99,4 +100,13 @@ public class FlinkDoFnFunction<InputT, OutputT>
     this.doFn.finishBundle(context);
   }
 
+  @Override
+  public void open(Configuration parameters) throws Exception {
+    doFn.setup();
+  }
+
+  @Override
+  public void close() throws Exception {
+    doFn.teardown();
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java
index 6e673fc..5013b90 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java
@@ -27,6 +27,7 @@ import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
 
 import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.util.Collector;
 
 import java.util.Map;
@@ -75,14 +76,15 @@ public class FlinkMultiOutputDoFnFunction<InputT, OutputT>
       Iterable<WindowedValue<InputT>> values,
       Collector<WindowedValue<RawUnionValue>> out) throws Exception {
 
-    FlinkProcessContext<InputT, OutputT> context = new FlinkMultiOutputProcessContext<>(
-        serializedOptions.getPipelineOptions(),
-        getRuntimeContext(),
-        doFn,
-        windowingStrategy,
-        out,
-        outputMap,
-        sideInputs);
+    FlinkProcessContext<InputT, OutputT> context =
+        new FlinkMultiOutputProcessContext<>(
+            serializedOptions.getPipelineOptions(),
+            getRuntimeContext(),
+            doFn,
+            windowingStrategy,
+            out,
+            outputMap,
+            sideInputs);
 
     this.doFn.startBundle(context);
 
@@ -97,14 +99,23 @@ public class FlinkMultiOutputDoFnFunction<InputT, OutputT>
       // side inputs and window access also only works if an element
       // is in only one window
       for (WindowedValue<InputT> value : values) {
-        for (WindowedValue<InputT> explodedValue: value.explodeWindows()) {
+        for (WindowedValue<InputT> explodedValue : value.explodeWindows()) {
           context = context.forWindowedValue(value);
           doFn.processElement(context);
         }
       }
     }
 
-
     this.doFn.finishBundle(context);
   }
+
+  @Override
+  public void open(Configuration parameters) throws Exception {
+    doFn.setup();
+  }
+
+  @Override
+  public void close() throws Exception {
+    doFn.teardown();
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkAbstractParDoWrapper.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkAbstractParDoWrapper.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkAbstractParDoWrapper.java
index 74ec66a..a9dd865 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkAbstractParDoWrapper.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkAbstractParDoWrapper.java
@@ -70,6 +70,7 @@ public abstract class FlinkAbstractParDoWrapper<IN, OUTDF, OUTFL> extends RichFl
 
   @Override
   public void open(Configuration parameters) throws Exception {
+    doFn.setup();
   }
 
   @Override
@@ -78,6 +79,7 @@ public abstract class FlinkAbstractParDoWrapper<IN, OUTDF, OUTFL> extends RichFl
       // we have initialized the context
       this.doFn.finishBundle(this.context);
     }
+    this.doFn.teardown();
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupAlsoByWindowWrapper.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupAlsoByWindowWrapper.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupAlsoByWindowWrapper.java
index 103a12b..4fddb53 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupAlsoByWindowWrapper.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupAlsoByWindowWrapper.java
@@ -252,6 +252,7 @@ public class FlinkGroupAlsoByWindowWrapper<K, VIN, VACC, VOUT>
   @Override
   public void open() throws Exception {
     super.open();
+    operator.setup();
     this.context = new ProcessContext(operator, new TimestampedCollector<>(output), this.timerInternals);
     operator.startBundle(context);
   }
@@ -351,6 +352,7 @@ public class FlinkGroupAlsoByWindowWrapper<K, VIN, VACC, VOUT>
   @Override
   public void close() throws Exception {
     operator.finishBundle(context);
+    operator.teardown();
     super.close();
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
index f4ce516..c08d185 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
@@ -24,6 +24,8 @@ import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.TupleTag;
 
 import org.apache.spark.api.java.function.FlatMapFunction;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.util.Iterator;
 import java.util.LinkedList;
@@ -40,6 +42,8 @@ public class DoFnFunction<InputT, OutputT>
     implements FlatMapFunction<Iterator<WindowedValue<InputT>>,
     WindowedValue<OutputT>> {
   private final OldDoFn<InputT, OutputT> mFunction;
+  private static final Logger LOG = LoggerFactory.getLogger(DoFnFunction.class);
+
   private final SparkRuntimeContext mRuntimeContext;
   private final Map<TupleTag<?>, BroadcastHelper<?>> mSideInputs;
 
@@ -61,8 +65,23 @@ public class DoFnFunction<InputT, OutputT>
       Exception {
     ProcCtxt ctxt = new ProcCtxt(mFunction, mRuntimeContext, mSideInputs);
     ctxt.setup();
-    mFunction.startBundle(ctxt);
-    return ctxt.getOutputIterable(iter, mFunction);
+    try {
+      mFunction.setup();
+      mFunction.startBundle(ctxt);
+      return ctxt.getOutputIterable(iter, mFunction);
+    } catch (Exception e) {
+      try {
+        // This teardown handles exceptions thrown by setup() and startBundle(). Teardown
+        // after successful processing, or after an exception in processElement() or
+        // finishBundle(), is performed by the iterator returned from ctxt.getOutputIterable.
+        mFunction.teardown();
+      } catch (Exception teardownException) {
+        LOG.error(
+            "Suppressing exception while tearing down Function {}", mFunction, teardownException);
+        e.addSuppressed(teardownException);
+      }
+      throw e;
+    }
   }
 
   private class ProcCtxt extends SparkProcessContext<InputT, OutputT, WindowedValue<OutputT>> {

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
index e33578d..abf0e83 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
@@ -65,6 +65,7 @@ class MultiDoFnFunction<InputT, OutputT>
   public Iterable<Tuple2<TupleTag<?>, WindowedValue<?>>>
       call(Iterator<WindowedValue<InputT>> iter) throws Exception {
     ProcCtxt ctxt = new ProcCtxt(mFunction, mRuntimeContext, mSideInputs);
+    mFunction.setup();
     mFunction.startBundle(ctxt);
     ctxt.setup();
     return ctxt.getOutputIterable(iter, mFunction);

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
index 2f06a1c..1cdbd92 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
@@ -238,6 +238,7 @@ public abstract class SparkProcessContext<InputT, OutputT, ValueT>
           try {
             doFn.processElement(SparkProcessContext.this);
           } catch (Exception e) {
+            handleProcessingException(e);
             throw new SparkProcessException(e);
           }
           outputIterator = getOutputIterator();
@@ -249,15 +250,31 @@ public abstract class SparkProcessContext<InputT, OutputT, ValueT>
               calledFinish = true;
               doFn.finishBundle(SparkProcessContext.this);
             } catch (Exception e) {
+              handleProcessingException(e);
               throw new SparkProcessException(e);
             }
             outputIterator = getOutputIterator();
             continue; // try to consume outputIterator from start of loop
           }
+          try {
+            doFn.teardown();
+          } catch (Exception e) {
+            LOG.error(
+                "Suppressing teardown exception that occurred after processing entire input", e);
+          }
           return endOfData();
         }
       }
     }
+
+    private void handleProcessingException(Exception e) {
+      try {
+        doFn.teardown();
+      } catch (Exception e1) {
+        LOG.error("Exception while cleaning up DoFn", e1);
+        e.addSuppressed(e1);
+      }
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
index a06467e..80b67af 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
@@ -342,6 +342,20 @@ public abstract class DoFn<InputT, OutputT> implements Serializable, HasDisplayD
 
   /////////////////////////////////////////////////////////////////////////////
 
+
+  /**
+   * Annotation for the method to use to prepare an instance for processing bundles of elements. The
+   * method annotated with this must satisfy the following constraint:
+   * <ul>
+   *   <li>It must have zero arguments.
+   * </ul>
+   */
+  @Documented
+  @Retention(RetentionPolicy.RUNTIME)
+  @Target(ElementType.METHOD)
+  public @interface Setup {
+  }
+
   /**
    * Annotation for the method to use to prepare an instance for processing a batch of elements.
    * The method annotated with this must satisfy the following constraints:
@@ -371,7 +385,7 @@ public abstract class DoFn<InputT, OutputT> implements Serializable, HasDisplayD
   public @interface ProcessElement {}
 
   /**
-   * Annotation for the method to use to prepare an instance for processing a batch of elements.
+   * Annotation for the method to use to finish processing a batch of elements.
    * The method annotated with this must satisfy the following constraints:
    * <ul>
    *   <li>It must have at least one argument.
@@ -383,6 +397,21 @@ public abstract class DoFn<InputT, OutputT> implements Serializable, HasDisplayD
   @Target(ElementType.METHOD)
   public @interface FinishBundle {}
 
+
+  /**
+   * Annotation for the method to use to clean up this instance after processing bundles of
+   * elements. No other method will be called after a call to the annotated method is made.
+   * The method annotated with this must satisfy the following constraint:
+   * <ul>
+   *   <li>It must have zero arguments.
+   * </ul>
+   */
+  @Documented
+  @Retention(RetentionPolicy.RUNTIME)
+  @Target(ElementType.METHOD)
+  public @interface Teardown {
+  }
+
   /**
    * Returns an {@link Aggregator} with aggregation logic specified by the
    * {@link CombineFn} argument. The name provided must be unique across

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java
index 3dfda55..bf04041 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java
@@ -17,11 +17,15 @@
  */
 package org.apache.beam.sdk.transforms;
 
+import static com.google.common.base.Preconditions.checkState;
+
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.transforms.DoFn.ExtraContextFactory;
 import org.apache.beam.sdk.transforms.DoFn.FinishBundle;
 import org.apache.beam.sdk.transforms.DoFn.ProcessElement;
+import org.apache.beam.sdk.transforms.DoFn.Setup;
 import org.apache.beam.sdk.transforms.DoFn.StartBundle;
+import org.apache.beam.sdk.transforms.DoFn.Teardown;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
@@ -302,6 +306,15 @@ public abstract class DoFnReflector {
         new TypeParameter<OutputT>() {});
   }
 
+  @VisibleForTesting
+  static void verifyLifecycleMethodArguments(Method m) {
+    if (m == null) {
+      return;
+    }
+    checkState(void.class.equals(m.getReturnType()), "%s must have void return type", format(m));
+    checkState(m.getGenericParameterTypes().length == 0, "%s must take zero arguments", format(m));
+  }
+
   /**
    * Verify the method arguments for a given {@link DoFn} method.
    *
@@ -392,6 +405,8 @@ public abstract class DoFnReflector {
 
   /** Interface for invoking the {@code OldDoFn} processing methods. */
   public interface DoFnInvoker<InputT, OutputT>  {
+    /** Invoke the {@link DoFn.Setup @Setup} method on the bound {@code DoFn}. */
+    void invokeSetup();
     /** Invoke {@link OldDoFn#startBundle} on the bound {@code OldDoFn}. */
     void invokeStartBundle(
         DoFn<InputT, OutputT>.Context c,
@@ -401,6 +416,9 @@ public abstract class DoFnReflector {
         DoFn<InputT, OutputT>.Context c,
         ExtraContextFactory<InputT, OutputT> extra);
 
+    /** Invoke the {@link DoFn.Teardown @Teardown} method on the bound {@code DoFn}. */
+    void invokeTeardown();
+
     /** Invoke {@link OldDoFn#processElement} on the bound {@code OldDoFn}. */
     public void invokeProcessElement(
         DoFn<InputT, OutputT>.ProcessContext c,
@@ -412,9 +430,11 @@ public abstract class DoFnReflector {
    */
   private static class GenericDoFnReflector extends DoFnReflector {
 
+    private final Method setup;
     private final Method startBundle;
     private final Method processElement;
     private final Method finishBundle;
+    private final Method teardown;
     private final List<AdditionalParameter> processElementArgs;
     private final List<AdditionalParameter> startBundleArgs;
     private final List<AdditionalParameter> finishBundleArgs;
@@ -424,13 +444,17 @@ public abstract class DoFnReflector {
         @SuppressWarnings("rawtypes") Class<? extends DoFn> fn) {
       // Locate the annotated methods
       this.processElement = findAnnotatedMethod(ProcessElement.class, fn, true);
+      this.setup = findAnnotatedMethod(Setup.class, fn, false);
       this.startBundle = findAnnotatedMethod(StartBundle.class, fn, false);
       this.finishBundle = findAnnotatedMethod(FinishBundle.class, fn, false);
+      this.teardown = findAnnotatedMethod(Teardown.class, fn, false);
 
       // Verify that their method arguments satisfy our conditions.
       this.processElementArgs = verifyProcessMethodArguments(processElement);
       this.startBundleArgs = verifyBundleMethodArguments(startBundle);
       this.finishBundleArgs = verifyBundleMethodArguments(finishBundle);
+      verifyLifecycleMethodArguments(setup);
+      verifyLifecycleMethodArguments(teardown);
 
       this.constructor = createInvokerConstructor(fn);
     }
@@ -552,8 +576,17 @@ public abstract class DoFnReflector {
           .intercept(InvokerDelegation.create(
               startBundle, BeforeDelegation.INVOKE_PREPARE_FOR_PROCESSING, startBundleArgs))
           .method(ElementMatchers.named("invokeFinishBundle"))
-          .intercept(InvokerDelegation.create(
-              finishBundle, BeforeDelegation.NOOP, finishBundleArgs));
+          .intercept(InvokerDelegation.create(finishBundle,
+              BeforeDelegation.NOOP,
+              finishBundleArgs))
+          .method(ElementMatchers.named("invokeSetup"))
+          .intercept(InvokerDelegation.create(setup,
+              BeforeDelegation.NOOP,
+              Collections.<AdditionalParameter>emptyList()))
+          .method(ElementMatchers.named("invokeTeardown"))
+          .intercept(InvokerDelegation.create(teardown,
+              BeforeDelegation.NOOP,
+              Collections.<AdditionalParameter>emptyList()));
 
       @SuppressWarnings("unchecked")
       Class<? extends DoFnInvoker<?, ?>> dynamicClass = (Class<? extends DoFnInvoker<?, ?>>) builder
@@ -736,6 +769,11 @@ public abstract class DoFnReflector {
     }
 
     @Override
+    public void setup() throws Exception {
+      invoker.invokeSetup();
+    }
+
+    @Override
     public void startBundle(OldDoFn<InputT, OutputT>.Context c) throws Exception {
       ContextAdapter<InputT, OutputT> adapter = new ContextAdapter<>(fn, c);
       invoker.invokeStartBundle(adapter, adapter);
@@ -748,6 +786,11 @@ public abstract class DoFnReflector {
     }
 
     @Override
+    public void teardown() {
+      invoker.invokeTeardown();
+    }
+
+    @Override
     public void processElement(OldDoFn<InputT, OutputT>.ProcessContext c) throws Exception {
       ProcessContextAdapter<InputT, OutputT> adapter = new ProcessContextAdapter<>(fn, c);
       invoker.invokeProcessElement(adapter, adapter);
@@ -940,15 +983,20 @@ public abstract class DoFnReflector {
           new MethodDescription.ForLoadedMethod(target)).resolve(instrumentedMethod);
       ParameterList<?> params = targetMethod.getParameters();
 
-      // Instructions to setup the parameters for the call
-      ArrayList<StackManipulation> parameters = new ArrayList<>(args.size() + 1);
-      // 1. The first argument in the delegate method must be the context. This corresponds to
-      //    the first argument in the instrumented method, so copy that.
-      parameters.add(MethodVariableAccess.of(
-          params.get(0).getType().getSuperClass()).loadOffset(1));
-      // 2. For each of the extra arguments push the appropriate value.
-      for (AdditionalParameter arg : args) {
-        parameters.add(pushArgument(arg, instrumentedMethod));
+      List<StackManipulation> parameters;
+      if (!params.isEmpty()) {
+        // Instructions to setup the parameters for the call
+        parameters = new ArrayList<>(args.size() + 1);
+        // 1. The first argument in the delegate method must be the context. This corresponds to
+        //    the first argument in the instrumented method, so copy that.
+        parameters.add(MethodVariableAccess.of(params.get(0).getType().getSuperClass())
+            .loadOffset(1));
+        // 2. For each of the extra arguments push the appropriate value.
+        for (AdditionalParameter arg : args) {
+          parameters.add(pushArgument(arg, instrumentedMethod));
+        }
+      } else {
+        parameters = Collections.emptyList();
       }
 
       return new StackManipulation.Compound(

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java
index 443599a..84cd997 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java
@@ -339,6 +339,17 @@ public abstract class OldDoFn<InputT, OutputT> implements Serializable, HasDispl
   private boolean aggregatorsAreFinal;
 
   /**
+   * Prepares this {@code OldDoFn} instance for processing bundles.
+   *
+   * <p>{@link #setup()} will be called at most once per {@code OldDoFn} instance, and before
+   * any other {@code OldDoFn} method is called.
+   *
+   * <p>By default, does nothing.
+   */
+  public void setup() throws Exception {
+  }
+
+  /**
    * Prepares this {@code OldDoFn} instance for processing a batch of elements.
    *
    * <p>By default, does nothing.
@@ -373,6 +384,20 @@ public abstract class OldDoFn<InputT, OutputT> implements Serializable, HasDispl
   }
 
   /**
+   * Cleans up this {@code OldDoFn}.
+   *
+   * <p>{@link #teardown()} will be called before the {@link PipelineRunner} discards an
+   * {@code OldDoFn} instance, including due to another {@code OldDoFn} method throwing an
+   * {@link Exception}. No other methods will be called after a call to {@link #teardown()}.
+   *
+   * <p>By default, does nothing.
+   */
+  public void teardown() throws Exception {
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+
+  /**
    * {@inheritDoc}
    *
    * <p>By default, does not register any display data. Implementors may override this method

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
index ca6d9b2..aa57531 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
@@ -67,11 +67,11 @@ import java.util.List;
  * For each bundle of input elements processing proceeds as follows:
  *
  * <ol>
- *   <li>If required, a fresh instance of the argument {@link OldDoFn} is created
- *     on a worker. This may be through deserialization or other means. A
- *     {@link PipelineRunner} may reuse {@link OldDoFn} instances for multiple bundles.
- *     A {@link OldDoFn} that has terminated abnormally (by throwing an {@link Exception}
- *     will never be reused.</li>
+ *   <li>If required, a fresh instance of the argument {@link DoFn} is created
+ *     on a worker, which may be through deserialization or other means, and its
+ *     {@link DoFn.Setup @Setup} method is called. A {@link PipelineRunner} may reuse
+ *     {@link DoFn} instances for multiple bundles. A {@link DoFn} that has terminated
+ *     abnormally (by throwing an {@link Exception}) will never be reused.</li>
  *   <li>The {@link OldDoFn OldDoFn's} {@link OldDoFn#startBundle} method is called to
  *     initialize it. If this method is not overridden, the call may be optimized
  *     away.</li>
@@ -83,6 +83,11 @@ import java.util.List;
  *     {@link OldDoFn#finishBundle}
  *     until a new call to {@link OldDoFn#startBundle} has occurred.
  *     If this method is not overridden, this call may be optimized away.</li>
+ *   <li>If any of the {@link DoFn.Setup @Setup}, {@link DoFn.StartBundle @StartBundle},
+ *     {@link DoFn.ProcessElement @ProcessElement} or {@link DoFn.FinishBundle @FinishBundle}
+ *     methods throw an exception, its {@link DoFn.Teardown @Teardown} method is called.</li>
+ *   <li>If a runner will no longer use a {@link DoFn}, its {@link DoFn.Teardown @Teardown}
+ *     method is called on the discarded instance.</li>
  * </ol>
  *
  * Each of the calls to any of the {@link OldDoFn OldDoFn's} processing

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnReflectorTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnReflectorTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnReflectorTest.java
index c47e0cf..e05e5e2 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnReflectorTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnReflectorTest.java
@@ -25,6 +25,8 @@ import org.apache.beam.sdk.transforms.DoFn.Context;
 import org.apache.beam.sdk.transforms.DoFn.ExtraContextFactory;
 import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
 import org.apache.beam.sdk.transforms.DoFn.ProcessElement;
+import org.apache.beam.sdk.transforms.DoFn.Setup;
+import org.apache.beam.sdk.transforms.DoFn.Teardown;
 import org.apache.beam.sdk.transforms.dofnreflector.DoFnReflectorTestHelper;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.UserCodeException;
@@ -53,6 +55,8 @@ public class DoFnReflectorTest {
     public boolean wasProcessElementInvoked = false;
     public boolean wasStartBundleInvoked = false;
     public boolean wasFinishBundleInvoked = false;
+    public boolean wasSetupInvoked = false;
+    public boolean wasTeardownInvoked = false;
     private final String name;
 
     public Invocations(String name) {
@@ -144,6 +148,33 @@ public class DoFnReflectorTest {
     }
   }
 
+  private void checkInvokeSetupWorks(DoFnReflector r, Invocations... invocations) throws Exception {
+    assertTrue("Need at least one invocation to check", invocations.length >= 1);
+    for (Invocations invocation : invocations) {
+      assertFalse("Should not yet have called setup on " + invocation.name,
+          invocation.wasSetupInvoked);
+    }
+    r.bindInvoker(fn).invokeSetup();
+    for (Invocations invocation : invocations) {
+      assertTrue("Should have called setup on " + invocation.name,
+          invocation.wasSetupInvoked);
+    }
+  }
+
+  private void checkInvokeTeardownWorks(DoFnReflector r, Invocations... invocations)
+      throws Exception {
+    assertTrue("Need at least one invocation to check", invocations.length >= 1);
+    for (Invocations invocation : invocations) {
+      assertFalse("Should not yet have called teardown on " + invocation.name,
+          invocation.wasTeardownInvoked);
+    }
+    r.bindInvoker(fn).invokeTeardown();
+    for (Invocations invocation : invocations) {
+      assertTrue("Should have called teardown on " + invocation.name,
+          invocation.wasTeardownInvoked);
+    }
+  }
+
   @Test
   public void testDoFnWithNoExtraContext() throws Exception {
     final Invocations invocations = new Invocations("AnonymousClass");
@@ -325,6 +356,40 @@ public class DoFnReflectorTest {
   }
 
   @Test
+  public void testDoFnWithSetupTeardown() throws Exception {
+    final Invocations invocations = new Invocations("AnonymousClass");
+    DoFnReflector reflector = underTest(new DoFn<String, String>() {
+      @ProcessElement
+      public void processElement(@SuppressWarnings("unused") ProcessContext c) {}
+
+      @StartBundle
+      public void startBundle(Context c) {
+        invocations.wasStartBundleInvoked = true;
+        assertSame(c, mockContext);
+      }
+
+      @FinishBundle
+      public void finishBundle(Context c) {
+        invocations.wasFinishBundleInvoked = true;
+        assertSame(c, mockContext);
+      }
+
+      @Setup
+      public void before() {
+        invocations.wasSetupInvoked = true;
+      }
+
+      @Teardown
+      public void after() {
+        invocations.wasTeardownInvoked = true;
+      }
+    });
+
+    checkInvokeSetupWorks(reflector, invocations);
+    checkInvokeTeardownWorks(reflector, invocations);
+  }
+
+  @Test
   public void testNoProcessElement() throws Exception {
     thrown.expect(IllegalStateException.class);
     thrown.expectMessage("No method annotated with @ProcessElement found");

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
index 7fe053c..8460124 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
@@ -24,17 +24,18 @@ import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.include
 import static org.apache.beam.sdk.util.SerializableUtils.serializeToByteArray;
 import static org.apache.beam.sdk.util.StringUtils.byteArrayToJsonString;
 import static org.apache.beam.sdk.util.StringUtils.jsonStringToByteArray;
-
 import static com.google.common.base.Preconditions.checkNotNull;
-
 import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.is;
 import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder;
 import static org.hamcrest.collection.IsIterableContainingInOrder.contains;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
+import static org.junit.Assert.fail;
 
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.AtomicCoder;
@@ -53,6 +54,7 @@ import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionList;
 import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TimestampedValue;
@@ -60,6 +62,7 @@ import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
 
 import com.fasterxml.jackson.annotation.JsonCreator;
+
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.junit.Rule;
@@ -77,6 +80,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 /**
  * Tests for ParDo.
@@ -169,8 +173,10 @@ public class ParDoTest implements Serializable {
   }
 
   static class TestDoFn extends DoFn<Integer, String> {
-    enum State { UNSTARTED, STARTED, PROCESSING, FINISHED }
-    State state = State.UNSTARTED;
+    enum State {NOT_SET_UP, UNSTARTED, STARTED, PROCESSING, FINISHED}
+
+
+    State state = State.NOT_SET_UP;
 
     final List<PCollectionView<Integer>> sideInputViews = new ArrayList<>();
     final List<TupleTag<String>> sideOutputTupleTags = new ArrayList<>();
@@ -184,6 +190,12 @@ public class ParDoTest implements Serializable {
       this.sideOutputTupleTags.addAll(sideOutputTupleTags);
     }
 
+    @Setup
+    public void prepare() {
+      assertEquals(State.NOT_SET_UP, state);
+      state = State.UNSTARTED;
+    }
+
     @StartBundle
     public void startBundle(Context c) {
       assertEquals(State.UNSTARTED, state);
@@ -1463,4 +1475,404 @@ public class ParDoTest implements Serializable {
     assertThat(displayData, includesDisplayDataFrom(fn));
     assertThat(displayData, hasDisplayItem("fn", fn.getClass()));
   }
+
+  /**
+   * Applies a lifecycle-order-checking {@link OldDoFn} to the flattened output of two
+   * {@link Create} transforms.
+   */
+  @Test
+  @Category(RunnableOnService.class)
+  public void testFnCallSequence() {
+    TestPipeline pipeline = TestPipeline.create();
+    PCollection<Integer> impolite = pipeline.apply("Impolite", Create.of(1, 2, 4));
+    PCollection<Integer> polite = pipeline.apply("Polite", Create.of(3, 5, 6, 7));
+    PCollectionList.of(impolite)
+        .and(polite)
+        .apply(Flatten.<Integer>pCollections())
+        .apply(ParDo.of(new CallSequenceEnforcingOldFn<Integer>()));
+    pipeline.run();
+  }
+
+  /**
+   * Same as {@code testFnCallSequence}, but through the multi-output form of {@link ParDo}
+   * (a main output tag and no side outputs).
+   */
+  @Test
+  @Category(RunnableOnService.class)
+  public void testFnCallSequenceMulti() {
+    TestPipeline pipeline = TestPipeline.create();
+    PCollection<Integer> impolite = pipeline.apply("Impolite", Create.of(1, 2, 4));
+    PCollection<Integer> polite = pipeline.apply("Polite", Create.of(3, 5, 6, 7));
+    PCollection<Integer> flattened =
+        PCollectionList.of(impolite).and(polite).apply(Flatten.<Integer>pCollections());
+    flattened.apply(
+        ParDo.of(new CallSequenceEnforcingOldFn<Integer>())
+            .withOutputTags(new TupleTag<Integer>() {}, TupleTagList.empty()));
+    pipeline.run();
+  }
+
+  /**
+   * An {@link OldDoFn} that fails an assertion if its lifecycle methods are invoked out of order:
+   * {@code setup} at most once and before anything else, {@code processElement} only inside a
+   * bundle opened by {@code startBundle} and not yet closed by {@code finishBundle}, and
+   * {@code teardown} at most once, after every started bundle has finished.
+   */
+  private static class CallSequenceEnforcingOldFn<T> extends OldDoFn<T, T> {
+    private boolean setupCalled = false;
+    private int startBundleCalls = 0;
+    private int finishBundleCalls = 0;
+    private boolean teardownCalled = false;
+
+    @Override
+    public void setup() {
+      assertThat("setup should not be called twice", setupCalled, is(false));
+      assertThat("setup should be called before startBundle", startBundleCalls, equalTo(0));
+      assertThat("setup should be called before finishBundle", finishBundleCalls, equalTo(0));
+      assertThat("setup should be called before teardown", teardownCalled, is(false));
+      setupCalled = true;
+    }
+
+    @Override
+    public void startBundle(Context c) {
+      assertThat("setup should have been called", setupCalled, is(true));
+      // A new bundle may only start once the previous one has finished.
+      assertThat(
+          "Even number of startBundle and finishBundle calls in startBundle",
+          startBundleCalls,
+          equalTo(finishBundleCalls));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+      startBundleCalls++;
+    }
+
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+      assertThat("startBundle should have been called", startBundleCalls, greaterThan(0));
+      // Exactly one bundle must be open while elements are processed.
+      assertThat(
+          "there should be one startBundle call with no call to finishBundle",
+          startBundleCalls,
+          equalTo(finishBundleCalls + 1));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+    }
+
+    @Override
+    public void finishBundle(Context c) {
+      assertThat("startBundle should have been called", startBundleCalls, greaterThan(0));
+      assertThat(
+          "there should be one bundle that has been started but not finished",
+          startBundleCalls,
+          equalTo(finishBundleCalls + 1));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+      finishBundleCalls++;
+    }
+
+    @Override
+    public void teardown() {
+      assertThat("setup should have been called before teardown", setupCalled, is(true));
+      assertThat(
+          "every started bundle should have been finished before teardown",
+          startBundleCalls,
+          anyOf(equalTo(finishBundleCalls)));
+      assertThat("teardown should not be called twice", teardownCalled, is(false));
+      teardownCalled = true;
+    }
+  }
+
+  /**
+   * Applies a lifecycle-order-checking annotation-based {@link DoFn} to the flattened output of
+   * two {@link Create} transforms.
+   */
+  @Test
+  @Category(RunnableOnService.class)
+  public void testFnWithContextCallSequence() {
+    TestPipeline pipeline = TestPipeline.create();
+    PCollection<Integer> impolite = pipeline.apply("Impolite", Create.of(1, 2, 4));
+    PCollection<Integer> polite = pipeline.apply("Polite", Create.of(3, 5, 6, 7));
+    PCollectionList.of(impolite)
+        .and(polite)
+        .apply(Flatten.<Integer>pCollections())
+        .apply(ParDo.of(new CallSequenceEnforcingFn<Integer>()));
+    pipeline.run();
+  }
+
+  /**
+   * Same as {@code testFnWithContextCallSequence}, but through the multi-output form of
+   * {@link ParDo} (a main output tag and no side outputs).
+   */
+  @Test
+  @Category(RunnableOnService.class)
+  public void testFnWithContextCallSequenceMulti() {
+    TestPipeline pipeline = TestPipeline.create();
+    PCollection<Integer> impolite = pipeline.apply("Impolite", Create.of(1, 2, 4));
+    PCollection<Integer> polite = pipeline.apply("Polite", Create.of(3, 5, 6, 7));
+    PCollection<Integer> flattened =
+        PCollectionList.of(impolite).and(polite).apply(Flatten.<Integer>pCollections());
+    flattened.apply(
+        ParDo.of(new CallSequenceEnforcingFn<Integer>())
+            .withOutputTags(new TupleTag<Integer>() {}, TupleTagList.empty()));
+    pipeline.run();
+  }
+
+  /**
+   * An annotation-based {@link DoFn} that fails an assertion if its lifecycle methods are invoked
+   * out of order: {@code @Setup} at most once and before anything else, {@code @ProcessElement}
+   * only inside a bundle opened by {@code @StartBundle} and not yet closed by
+   * {@code @FinishBundle}, and {@code @Teardown} at most once, after every started bundle has
+   * finished.
+   */
+  private static class CallSequenceEnforcingFn<T> extends DoFn<T, T> {
+    private boolean setupCalled = false;
+    private int startBundleCalls = 0;
+    private int finishBundleCalls = 0;
+    private boolean teardownCalled = false;
+
+    @Setup
+    public void before() {
+      assertThat("setup should not be called twice", setupCalled, is(false));
+      assertThat("setup should be called before startBundle", startBundleCalls, equalTo(0));
+      assertThat("setup should be called before finishBundle", finishBundleCalls, equalTo(0));
+      assertThat("setup should be called before teardown", teardownCalled, is(false));
+      setupCalled = true;
+    }
+
+    @StartBundle
+    public void begin(Context c) {
+      assertThat("setup should have been called", setupCalled, is(true));
+      // A new bundle may only start once the previous one has finished.
+      assertThat("Even number of startBundle and finishBundle calls in startBundle",
+          startBundleCalls,
+          equalTo(finishBundleCalls));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+      startBundleCalls++;
+    }
+
+    @ProcessElement
+    public void process(ProcessContext c) throws Exception {
+      assertThat("startBundle should have been called", startBundleCalls, greaterThan(0));
+      // Exactly one bundle must be open while elements are processed.
+      assertThat("there should be one startBundle call with no call to finishBundle",
+          startBundleCalls,
+          equalTo(finishBundleCalls + 1));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+    }
+
+    @FinishBundle
+    public void end(Context c) {
+      assertThat("startBundle should have been called", startBundleCalls, greaterThan(0));
+      assertThat("there should be one bundle that has been started but not finished",
+          startBundleCalls,
+          equalTo(finishBundleCalls + 1));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+      finishBundleCalls++;
+    }
+
+    @Teardown
+    public void after() {
+      assertThat("setup should have been called before teardown", setupCalled, is(true));
+      assertThat("every started bundle should have been finished before teardown",
+          startBundleCalls,
+          anyOf(equalTo(finishBundleCalls)));
+      assertThat("teardown should not be called twice", teardownCalled, is(false));
+      teardownCalled = true;
+    }
+  }
+
+  /** Verifies that an {@link OldDoFn} is torn down after its {@code setup} throws. */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testTeardownCalledAfterExceptionInSetup() {
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.SETUP);
+    p
+        .apply(Create.of(1, 2, 3))
+        .apply(ParDo.of(fn));
+    // teardownCalled is static and shared across tests: clear it so that a teardown recorded by
+    // an earlier test cannot make the assertion below pass vacuously.
+    ExceptionThrowingOldFn.teardownCalled.set(false);
+    try {
+      p.run();
+      fail("Pipeline should have failed with an exception");
+    } catch (Exception e) {
+      assertThat(
+          "Function should have been torn down after exception",
+          ExceptionThrowingOldFn.teardownCalled.get(),
+          is(true));
+    }
+  }
+
+  /** Verifies that an {@link OldDoFn} is torn down after its {@code startBundle} throws. */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testTeardownCalledAfterExceptionInStartBundle() {
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.START_BUNDLE);
+    p
+        .apply(Create.of(1, 2, 3))
+        .apply(ParDo.of(fn));
+    // teardownCalled is static and shared across tests: clear it so that a teardown recorded by
+    // an earlier test cannot make the assertion below pass vacuously.
+    ExceptionThrowingOldFn.teardownCalled.set(false);
+    try {
+      p.run();
+      fail("Pipeline should have failed with an exception");
+    } catch (Exception e) {
+      assertThat(
+          "Function should have been torn down after exception",
+          ExceptionThrowingOldFn.teardownCalled.get(),
+          is(true));
+    }
+  }
+
+  /** Verifies that an {@link OldDoFn} is torn down after its {@code processElement} throws. */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testTeardownCalledAfterExceptionInProcessElement() {
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.PROCESS_ELEMENT);
+    p
+        .apply(Create.of(1, 2, 3))
+        .apply(ParDo.of(fn));
+    // teardownCalled is static and shared across tests: clear it so that a teardown recorded by
+    // an earlier test cannot make the assertion below pass vacuously.
+    ExceptionThrowingOldFn.teardownCalled.set(false);
+    try {
+      p.run();
+      fail("Pipeline should have failed with an exception");
+    } catch (Exception e) {
+      assertThat(
+          "Function should have been torn down after exception",
+          ExceptionThrowingOldFn.teardownCalled.get(),
+          is(true));
+    }
+  }
+
+  /** Verifies that an {@link OldDoFn} is torn down after its {@code finishBundle} throws. */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testTeardownCalledAfterExceptionInFinishBundle() {
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.FINISH_BUNDLE);
+    p
+        .apply(Create.of(1, 2, 3))
+        .apply(ParDo.of(fn));
+    // teardownCalled is static and shared across tests: clear it so that a teardown recorded by
+    // an earlier test cannot make the assertion below pass vacuously.
+    ExceptionThrowingOldFn.teardownCalled.set(false);
+    try {
+      p.run();
+      fail("Pipeline should have failed with an exception");
+    } catch (Exception e) {
+      assertThat(
+          "Function should have been torn down after exception",
+          ExceptionThrowingOldFn.teardownCalled.get(),
+          is(true));
+    }
+  }
+
+  /** Verifies that an annotation-based {@link DoFn} is torn down after its {@code @Setup} throws. */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWithContextTeardownCalledAfterExceptionInSetup() {
+    TestPipeline p = TestPipeline.create();
+    // Must be the DoFn variant; using ExceptionThrowingOldFn here would only re-test OldDoFn.
+    ExceptionThrowingFn fn = new ExceptionThrowingFn(MethodForException.SETUP);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    // teardownCalled is static and shared across tests: clear it so that a teardown recorded by
+    // an earlier test cannot make the assertion below pass vacuously.
+    ExceptionThrowingFn.teardownCalled.set(false);
+    try {
+      p.run();
+      fail("Pipeline should have failed with an exception");
+    } catch (Exception e) {
+      assertThat("Function should have been torn down after exception",
+          ExceptionThrowingFn.teardownCalled.get(),
+          is(true));
+    }
+  }
+
+  /**
+   * Verifies that an annotation-based {@link DoFn} is torn down after its {@code @StartBundle}
+   * method throws.
+   */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWithContextTeardownCalledAfterExceptionInStartBundle() {
+    TestPipeline p = TestPipeline.create();
+    // Must be the DoFn variant; using ExceptionThrowingOldFn here would only re-test OldDoFn.
+    ExceptionThrowingFn fn = new ExceptionThrowingFn(MethodForException.START_BUNDLE);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    // teardownCalled is static and shared across tests: clear it so that a teardown recorded by
+    // an earlier test cannot make the assertion below pass vacuously.
+    ExceptionThrowingFn.teardownCalled.set(false);
+    try {
+      p.run();
+      fail("Pipeline should have failed with an exception");
+    } catch (Exception e) {
+      assertThat("Function should have been torn down after exception",
+          ExceptionThrowingFn.teardownCalled.get(),
+          is(true));
+    }
+  }
+
+  /**
+   * Verifies that an annotation-based {@link DoFn} is torn down after its {@code @ProcessElement}
+   * method throws.
+   */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWithContextTeardownCalledAfterExceptionInProcessElement() {
+    TestPipeline p = TestPipeline.create();
+    // Must be the DoFn variant; using ExceptionThrowingOldFn here would only re-test OldDoFn.
+    ExceptionThrowingFn fn = new ExceptionThrowingFn(MethodForException.PROCESS_ELEMENT);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    // teardownCalled is static and shared across tests: clear it so that a teardown recorded by
+    // an earlier test cannot make the assertion below pass vacuously.
+    ExceptionThrowingFn.teardownCalled.set(false);
+    try {
+      p.run();
+      fail("Pipeline should have failed with an exception");
+    } catch (Exception e) {
+      assertThat("Function should have been torn down after exception",
+          ExceptionThrowingFn.teardownCalled.get(),
+          is(true));
+    }
+  }
+
+  /**
+   * Verifies that an annotation-based {@link DoFn} is torn down after its {@code @FinishBundle}
+   * method throws.
+   */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWithContextTeardownCalledAfterExceptionInFinishBundle() {
+    TestPipeline p = TestPipeline.create();
+    // Must be the DoFn variant; using ExceptionThrowingOldFn here would only re-test OldDoFn.
+    ExceptionThrowingFn fn = new ExceptionThrowingFn(MethodForException.FINISH_BUNDLE);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    // teardownCalled is static and shared across tests: clear it so that a teardown recorded by
+    // an earlier test cannot make the assertion below pass vacuously.
+    ExceptionThrowingFn.teardownCalled.set(false);
+    try {
+      p.run();
+      fail("Pipeline should have failed with an exception");
+    } catch (Exception e) {
+      assertThat("Function should have been torn down after exception",
+          ExceptionThrowingFn.teardownCalled.get(),
+          is(true));
+    }
+  }
+
+  /**
+   * An {@link OldDoFn} that throws from the lifecycle method selected by {@code toThrow} the first
+   * time that method is invoked on an instance, and records in {@link #teardownCalled} that
+   * {@code teardown} ran.
+   *
+   * <p>The flag is static, presumably so the test can observe it even when the runner invokes a
+   * copy of the function rather than the constructed instance (TODO confirm). Because it is
+   * shared across tests, callers must reset it before running a pipeline.
+   */
+  private static class ExceptionThrowingOldFn extends OldDoFn<Object, Object> {
+    static final AtomicBoolean teardownCalled = new AtomicBoolean(false);
+
+    private final MethodForException toThrow;
+    // Set once this instance has thrown, so each instance throws at most once.
+    private boolean thrown;
+
+    private ExceptionThrowingOldFn(MethodForException toThrow) {
+      this.toThrow = toThrow;
+    }
+
+    @Override
+    public void setup() throws Exception {
+      throwIfNecessary(MethodForException.SETUP);
+    }
+
+    @Override
+    public void startBundle(Context c) throws Exception {
+      throwIfNecessary(MethodForException.START_BUNDLE);
+    }
+
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+      throwIfNecessary(MethodForException.PROCESS_ELEMENT);
+    }
+
+    @Override
+    public void finishBundle(Context c) throws Exception {
+      throwIfNecessary(MethodForException.FINISH_BUNDLE);
+    }
+
+    /** Throws if {@code method} is the configured method and this instance has not thrown yet. */
+    private void throwIfNecessary(MethodForException method) throws Exception {
+      if (toThrow == method && !thrown) {
+        thrown = true;
+        throw new Exception("Hasn't yet thrown");
+      }
+    }
+
+    @Override
+    public void teardown() {
+      if (!thrown) {
+        fail("Expected to have a processing method throw an exception");
+      }
+      teardownCalled.set(true);
+    }
+  }
+
+
+  /**
+   * An annotation-based {@link DoFn} that throws from the lifecycle method selected by
+   * {@code toThrow} the first time that method is invoked on an instance, and records in
+   * {@link #teardownCalled} that its {@code @Teardown} method ran.
+   *
+   * <p>The flag is static, presumably so the test can observe it even when the runner invokes a
+   * copy of the function rather than the constructed instance (TODO confirm). Because it is
+   * shared across tests, callers must reset it before running a pipeline.
+   */
+  private static class ExceptionThrowingFn extends DoFn<Object, Object> {
+    static final AtomicBoolean teardownCalled = new AtomicBoolean(false);
+
+    private final MethodForException toThrow;
+    // Set once this instance has thrown, so each instance throws at most once.
+    private boolean thrown;
+
+    private ExceptionThrowingFn(MethodForException toThrow) {
+      this.toThrow = toThrow;
+    }
+
+    @Setup
+    public void before() throws Exception {
+      throwIfNecessary(MethodForException.SETUP);
+    }
+
+    @StartBundle
+    public void preBundle(Context c) throws Exception {
+      throwIfNecessary(MethodForException.START_BUNDLE);
+    }
+
+    @ProcessElement
+    public void perElement(ProcessContext c) throws Exception {
+      throwIfNecessary(MethodForException.PROCESS_ELEMENT);
+    }
+
+    @FinishBundle
+    public void postBundle(Context c) throws Exception {
+      throwIfNecessary(MethodForException.FINISH_BUNDLE);
+    }
+
+    /** Throws if {@code method} is the configured method and this instance has not thrown yet. */
+    private void throwIfNecessary(MethodForException method) throws Exception {
+      if (toThrow == method && !thrown) {
+        thrown = true;
+        throw new Exception("Hasn't yet thrown");
+      }
+    }
+
+    @Teardown
+    public void after() {
+      if (!thrown) {
+        fail("Expected to have a processing method throw an exception");
+      }
+      teardownCalled.set(true);
+    }
+  }
+
+  /** Selects the lifecycle method from which an exception-throwing test function throws. */
+  private enum MethodForException {
+    SETUP, START_BUNDLE, PROCESS_ELEMENT, FINISH_BUNDLE
+  }
 }


[4/5] incubator-beam git commit: Move ParDo Lifecycle tests to their own file

Posted by dh...@apache.org.
Move ParDo Lifecycle tests to their own file

These tests are not yet functional in all runners, and this makes them
easier to ignore.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/29cbdceb
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/29cbdceb
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/29cbdceb

Branch: refs/heads/master
Commit: 29cbdceb5b78ce86ad0d90050d7542b0d5b45362
Parents: 12abb1b
Author: Thomas Groh <tg...@google.com>
Authored: Thu Aug 11 10:45:43 2016 -0700
Committer: Dan Halperin <dh...@google.com>
Committed: Mon Aug 15 14:16:54 2016 -0700

----------------------------------------------------------------------
 runners/google-cloud-dataflow-java/pom.xml      |  10 +
 .../beam/sdk/transforms/ParDoLifecycleTest.java | 448 +++++++++++++++++++
 .../apache/beam/sdk/transforms/ParDoTest.java   | 405 -----------------
 3 files changed, 458 insertions(+), 405 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/29cbdceb/runners/google-cloud-dataflow-java/pom.xml
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/pom.xml b/runners/google-cloud-dataflow-java/pom.xml
index 86991b7..c32e184 100644
--- a/runners/google-cloud-dataflow-java/pom.xml
+++ b/runners/google-cloud-dataflow-java/pom.xml
@@ -60,6 +60,16 @@
             <beamUseDummyRunner>true</beamUseDummyRunner>
           </systemPropertyVariables>
         </configuration>
+        <executions>
+          <execution>
+            <id>runnable-on-service-tests</id>
+            <configuration>
+              <excludes>
+                <exclude>org/apache/beam/sdk/transforms/ParDoLifecycleTest.java</exclude>
+              </excludes>
+            </configuration>
+          </execution>
+        </executions>
       </plugin>
 
       <!-- Run CheckStyle pass on transforms, as they are release in

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/29cbdceb/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoLifecycleTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoLifecycleTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoLifecycleTest.java
new file mode 100644
index 0000000..272fea7
--- /dev/null
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoLifecycleTest.java
@@ -0,0 +1,448 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.sdk.transforms;
+
+import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.is;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.fail;
+
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.RunnableOnService;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.values.PCollectionList;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
+
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * Tests that {@link ParDo} invokes the lifecycle methods of {@link OldDoFn} and {@link DoFn}
+ * ({@code setup}, {@code startBundle}, {@code processElement}, {@code finishBundle} and
+ * {@code teardown}) in the appropriate sequence, and that {@code teardown} still runs after one
+ * of the other lifecycle methods throws.
+ */
+@RunWith(JUnit4.class)
+public class ParDoLifecycleTest {
+  @Test
+  @Category(RunnableOnService.class)
+  public void testOldFnCallSequence() {
+    TestPipeline p = TestPipeline.create();
+    PCollectionList.of(p.apply("Impolite", Create.of(1, 2, 4)))
+        .and(p.apply("Polite", Create.of(3, 5, 6, 7)))
+        .apply(Flatten.<Integer>pCollections())
+        .apply(ParDo.of(new CallSequenceEnforcingOldFn<Integer>()));
+
+    p.run();
+  }
+
+  @Test
+  @Category(RunnableOnService.class)
+  public void testOldFnCallSequenceMulti() {
+    TestPipeline p = TestPipeline.create();
+    PCollectionList.of(p.apply("Impolite", Create.of(1, 2, 4)))
+        .and(p.apply("Polite", Create.of(3, 5, 6, 7)))
+        .apply(Flatten.<Integer>pCollections())
+        .apply(ParDo.of(new CallSequenceEnforcingOldFn<Integer>())
+            .withOutputTags(new TupleTag<Integer>() {}, TupleTagList.empty()));
+
+    p.run();
+  }
+
+  /**
+   * An {@link OldDoFn} that fails an assertion if its lifecycle methods are invoked out of order:
+   * {@code setup} at most once and before anything else, {@code processElement} only inside a
+   * bundle opened by {@code startBundle} and not yet closed by {@code finishBundle}, and
+   * {@code teardown} at most once, after every started bundle has finished.
+   */
+  private static class CallSequenceEnforcingOldFn<T> extends OldDoFn<T, T> {
+    private boolean setupCalled = false;
+    private int startBundleCalls = 0;
+    private int finishBundleCalls = 0;
+    private boolean teardownCalled = false;
+
+    @Override
+    public void setup() {
+      assertThat("setup should not be called twice", setupCalled, is(false));
+      assertThat("setup should be called before startBundle", startBundleCalls, equalTo(0));
+      assertThat("setup should be called before finishBundle", finishBundleCalls, equalTo(0));
+      assertThat("setup should be called before teardown", teardownCalled, is(false));
+      setupCalled = true;
+    }
+
+    @Override
+    public void startBundle(Context c) {
+      assertThat("setup should have been called", setupCalled, is(true));
+      // A new bundle may only start once the previous one has finished.
+      assertThat(
+          "Even number of startBundle and finishBundle calls in startBundle",
+          startBundleCalls,
+          equalTo(finishBundleCalls));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+      startBundleCalls++;
+    }
+
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+      assertThat("startBundle should have been called", startBundleCalls, greaterThan(0));
+      // Exactly one bundle must be open while elements are processed.
+      assertThat(
+          "there should be one startBundle call with no call to finishBundle",
+          startBundleCalls,
+          equalTo(finishBundleCalls + 1));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+    }
+
+    @Override
+    public void finishBundle(Context c) {
+      assertThat("startBundle should have been called", startBundleCalls, greaterThan(0));
+      assertThat(
+          "there should be one bundle that has been started but not finished",
+          startBundleCalls,
+          equalTo(finishBundleCalls + 1));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+      finishBundleCalls++;
+    }
+
+    @Override
+    public void teardown() {
+      assertThat("setup should have been called before teardown", setupCalled, is(true));
+      assertThat(
+          "every started bundle should have been finished before teardown",
+          startBundleCalls,
+          anyOf(equalTo(finishBundleCalls)));
+      assertThat("teardown should not be called twice", teardownCalled, is(false));
+      teardownCalled = true;
+    }
+  }
+
+  @Test
+  @Category(RunnableOnService.class)
+  public void testFnCallSequence() {
+    TestPipeline p = TestPipeline.create();
+    PCollectionList.of(p.apply("Impolite", Create.of(1, 2, 4)))
+        .and(p.apply("Polite", Create.of(3, 5, 6, 7)))
+        .apply(Flatten.<Integer>pCollections())
+        .apply(ParDo.of(new CallSequenceEnforcingFn<Integer>()));
+
+    p.run();
+  }
+
+  @Test
+  @Category(RunnableOnService.class)
+  public void testFnCallSequenceMulti() {
+    TestPipeline p = TestPipeline.create();
+    PCollectionList.of(p.apply("Impolite", Create.of(1, 2, 4)))
+        .and(p.apply("Polite", Create.of(3, 5, 6, 7)))
+        .apply(Flatten.<Integer>pCollections())
+        .apply(ParDo.of(new CallSequenceEnforcingFn<Integer>())
+            .withOutputTags(new TupleTag<Integer>() {}, TupleTagList.empty()));
+
+    p.run();
+  }
+
+  /**
+   * An annotation-based {@link DoFn} that fails an assertion if its lifecycle methods are invoked
+   * out of order: {@code @Setup} at most once and before anything else, {@code @ProcessElement}
+   * only inside a bundle opened by {@code @StartBundle} and not yet closed by
+   * {@code @FinishBundle}, and {@code @Teardown} at most once, after every started bundle has
+   * finished.
+   */
+  private static class CallSequenceEnforcingFn<T> extends DoFn<T, T> {
+    private boolean setupCalled = false;
+    private int startBundleCalls = 0;
+    private int finishBundleCalls = 0;
+    private boolean teardownCalled = false;
+
+    @Setup
+    public void before() {
+      assertThat("setup should not be called twice", setupCalled, is(false));
+      assertThat("setup should be called before startBundle", startBundleCalls, equalTo(0));
+      assertThat("setup should be called before finishBundle", finishBundleCalls, equalTo(0));
+      assertThat("setup should be called before teardown", teardownCalled, is(false));
+      setupCalled = true;
+    }
+
+    @StartBundle
+    public void begin(Context c) {
+      assertThat("setup should have been called", setupCalled, is(true));
+      // A new bundle may only start once the previous one has finished.
+      assertThat("Even number of startBundle and finishBundle calls in startBundle",
+          startBundleCalls,
+          equalTo(finishBundleCalls));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+      startBundleCalls++;
+    }
+
+    @ProcessElement
+    public void process(ProcessContext c) throws Exception {
+      assertThat("startBundle should have been called", startBundleCalls, greaterThan(0));
+      // Exactly one bundle must be open while elements are processed.
+      assertThat("there should be one startBundle call with no call to finishBundle",
+          startBundleCalls,
+          equalTo(finishBundleCalls + 1));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+    }
+
+    @FinishBundle
+    public void end(Context c) {
+      assertThat("startBundle should have been called", startBundleCalls, greaterThan(0));
+      assertThat("there should be one bundle that has been started but not finished",
+          startBundleCalls,
+          equalTo(finishBundleCalls + 1));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+      finishBundleCalls++;
+    }
+
+    @Teardown
+    public void after() {
+      assertThat("setup should have been called before teardown", setupCalled, is(true));
+      assertThat("every started bundle should have been finished before teardown",
+          startBundleCalls,
+          anyOf(equalTo(finishBundleCalls)));
+      assertThat("teardown should not be called twice", teardownCalled, is(false));
+      teardownCalled = true;
+    }
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testTeardownCalledAfterExceptionInSetup() {
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.SETUP);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    assertTornDownAfterFailure(p, ExceptionThrowingOldFn.teardownCalled);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testTeardownCalledAfterExceptionInStartBundle() {
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.START_BUNDLE);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    assertTornDownAfterFailure(p, ExceptionThrowingOldFn.teardownCalled);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testTeardownCalledAfterExceptionInProcessElement() {
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.PROCESS_ELEMENT);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    assertTornDownAfterFailure(p, ExceptionThrowingOldFn.teardownCalled);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testTeardownCalledAfterExceptionInFinishBundle() {
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.FINISH_BUNDLE);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    assertTornDownAfterFailure(p, ExceptionThrowingOldFn.teardownCalled);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWithContextTeardownCalledAfterExceptionInSetup() {
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingFn fn = new ExceptionThrowingFn(MethodForException.SETUP);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    assertTornDownAfterFailure(p, ExceptionThrowingFn.teardownCalled);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWithContextTeardownCalledAfterExceptionInStartBundle() {
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingFn fn = new ExceptionThrowingFn(MethodForException.START_BUNDLE);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    assertTornDownAfterFailure(p, ExceptionThrowingFn.teardownCalled);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWithContextTeardownCalledAfterExceptionInProcessElement() {
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingFn fn = new ExceptionThrowingFn(MethodForException.PROCESS_ELEMENT);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    assertTornDownAfterFailure(p, ExceptionThrowingFn.teardownCalled);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWithContextTeardownCalledAfterExceptionInFinishBundle() {
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingFn fn = new ExceptionThrowingFn(MethodForException.FINISH_BUNDLE);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    assertTornDownAfterFailure(p, ExceptionThrowingFn.teardownCalled);
+  }
+
+  /**
+   * Runs {@code p}, requires it to fail with an exception, and then requires that the failing
+   * function's teardown ran, as recorded in {@code teardownCalled}.
+   *
+   * <p>The flag is static and shared by every test that uses the same function class, so it is
+   * cleared immediately before running; otherwise a teardown recorded by an earlier test would
+   * make this assertion pass vacuously.
+   */
+  private static void assertTornDownAfterFailure(TestPipeline p, AtomicBoolean teardownCalled) {
+    teardownCalled.set(false);
+    try {
+      p.run();
+      fail("Pipeline should have failed with an exception");
+    } catch (Exception e) {
+      assertThat(
+          "Function should have been torn down after exception",
+          teardownCalled.get(),
+          is(true));
+    }
+  }
+
+  /**
+   * An {@link OldDoFn} that throws from the lifecycle method selected by {@code toThrow} the first
+   * time that method is invoked on an instance, and records in {@link #teardownCalled} that
+   * {@code teardown} ran.
+   *
+   * <p>The flag is static, presumably so the test can observe it even when the runner invokes a
+   * copy of the function rather than the constructed instance (TODO confirm).
+   */
+  private static class ExceptionThrowingOldFn extends OldDoFn<Object, Object> {
+    static final AtomicBoolean teardownCalled = new AtomicBoolean(false);
+
+    private final MethodForException toThrow;
+    // Set once this instance has thrown, so each instance throws at most once.
+    private boolean thrown;
+
+    private ExceptionThrowingOldFn(MethodForException toThrow) {
+      this.toThrow = toThrow;
+    }
+
+    @Override
+    public void setup() throws Exception {
+      throwIfNecessary(MethodForException.SETUP);
+    }
+
+    @Override
+    public void startBundle(Context c) throws Exception {
+      throwIfNecessary(MethodForException.START_BUNDLE);
+    }
+
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+      throwIfNecessary(MethodForException.PROCESS_ELEMENT);
+    }
+
+    @Override
+    public void finishBundle(Context c) throws Exception {
+      throwIfNecessary(MethodForException.FINISH_BUNDLE);
+    }
+
+    /** Throws if {@code method} is the configured method and this instance has not thrown yet. */
+    private void throwIfNecessary(MethodForException method) throws Exception {
+      if (toThrow == method && !thrown) {
+        thrown = true;
+        throw new Exception("Hasn't yet thrown");
+      }
+    }
+
+    @Override
+    public void teardown() {
+      if (!thrown) {
+        fail("Expected to have a processing method throw an exception");
+      }
+      teardownCalled.set(true);
+    }
+  }
+
+  /**
+   * An annotation-based {@link DoFn} that throws from the lifecycle method selected by
+   * {@code toThrow} the first time that method is invoked on an instance, and records in
+   * {@link #teardownCalled} that its {@code @Teardown} method ran.
+   *
+   * <p>The flag is static, presumably so the test can observe it even when the runner invokes a
+   * copy of the function rather than the constructed instance (TODO confirm).
+   */
+  private static class ExceptionThrowingFn extends DoFn<Object, Object> {
+    static final AtomicBoolean teardownCalled = new AtomicBoolean(false);
+
+    private final MethodForException toThrow;
+    // Set once this instance has thrown, so each instance throws at most once.
+    private boolean thrown;
+
+    private ExceptionThrowingFn(MethodForException toThrow) {
+      this.toThrow = toThrow;
+    }
+
+    @Setup
+    public void before() throws Exception {
+      throwIfNecessary(MethodForException.SETUP);
+    }
+
+    @StartBundle
+    public void preBundle(Context c) throws Exception {
+      throwIfNecessary(MethodForException.START_BUNDLE);
+    }
+
+    @ProcessElement
+    public void perElement(ProcessContext c) throws Exception {
+      throwIfNecessary(MethodForException.PROCESS_ELEMENT);
+    }
+
+    @FinishBundle
+    public void postBundle(Context c) throws Exception {
+      throwIfNecessary(MethodForException.FINISH_BUNDLE);
+    }
+
+    /** Throws if {@code method} is the configured method and this instance has not thrown yet. */
+    private void throwIfNecessary(MethodForException method) throws Exception {
+      if (toThrow == method && !thrown) {
+        thrown = true;
+        throw new Exception("Hasn't yet thrown");
+      }
+    }
+
+    @Teardown
+    public void after() {
+      if (!thrown) {
+        fail("Expected to have a processing method throw an exception");
+      }
+      teardownCalled.set(true);
+    }
+  }
+
+  /** Selects the lifecycle method from which an exception-throwing test function throws. */
+  private enum MethodForException {
+    SETUP,
+    START_BUNDLE,
+    PROCESS_ELEMENT,
+    FINISH_BUNDLE
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/29cbdceb/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
index 8460124..c384114 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
@@ -28,14 +28,11 @@ import static com.google.common.base.Preconditions.checkNotNull;
 import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.greaterThan;
-import static org.hamcrest.Matchers.is;
 import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder;
 import static org.hamcrest.collection.IsIterableContainingInOrder.contains;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
-import static org.junit.Assert.fail;
 
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.AtomicCoder;
@@ -54,7 +51,6 @@ import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PCollectionList;
 import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TimestampedValue;
@@ -80,7 +76,6 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
-import java.util.concurrent.atomic.AtomicBoolean;
 
 /**
  * Tests for ParDo.
@@ -1475,404 +1470,4 @@ public class ParDoTest implements Serializable {
     assertThat(displayData, includesDisplayDataFrom(fn));
     assertThat(displayData, hasDisplayItem("fn", fn.getClass()));
   }
-
-  @Test
-  @Category(RunnableOnService.class)
-  public void testFnCallSequence() {
-    TestPipeline p = TestPipeline.create();
-    PCollectionList.of(p.apply("Impolite", Create.of(1, 2, 4)))
-        .and(p.apply("Polite", Create.of(3, 5, 6, 7)))
-        .apply(Flatten.<Integer>pCollections())
-        .apply(ParDo.of(new CallSequenceEnforcingOldFn<Integer>()));
-
-    p.run();
-  }
-
-  @Test
-  @Category(RunnableOnService.class)
-  public void testFnCallSequenceMulti() {
-    TestPipeline p = TestPipeline.create();
-    PCollectionList.of(p.apply("Impolite", Create.of(1, 2, 4)))
-        .and(p.apply("Polite", Create.of(3, 5, 6, 7)))
-        .apply(Flatten.<Integer>pCollections())
-        .apply(ParDo.of(new CallSequenceEnforcingOldFn<Integer>())
-                .withOutputTags(new TupleTag<Integer>() {}, TupleTagList.empty()));
-
-    p.run();
-  }
-
-  private static class CallSequenceEnforcingOldFn<T> extends OldDoFn<T, T> {
-    private boolean setupCalled = false;
-    private int startBundleCalls = 0;
-    private int finishBundleCalls = 0;
-    private boolean teardownCalled = false;
-
-    @Override
-    public void setup() {
-      assertThat("setup should not be called twice", setupCalled, is(false));
-      assertThat("setup should be called before startBundle", startBundleCalls, equalTo(0));
-      assertThat("setup should be called before finishBundle", finishBundleCalls, equalTo(0));
-      assertThat("setup should be called before teardown", teardownCalled, is(false));
-      setupCalled = true;
-    }
-
-    @Override
-    public void startBundle(Context c) {
-      assertThat("setup should have been called", setupCalled, is(true));
-      assertThat(
-          "Even number of startBundle and finishBundle calls in startBundle",
-          startBundleCalls,
-          equalTo(finishBundleCalls));
-      assertThat("teardown should not have been called", teardownCalled, is(false));
-      startBundleCalls++;
-    }
-
-    @Override
-    public void processElement(ProcessContext c) throws Exception {
-      assertThat("startBundle should have been called", startBundleCalls, greaterThan(0));
-      assertThat(
-          "there should be one startBundle call with no call to finishBundle",
-          startBundleCalls,
-          equalTo(finishBundleCalls + 1));
-      assertThat("teardown should not have been called", teardownCalled, is(false));
-    }
-
-    @Override
-    public void finishBundle(Context c) {
-      assertThat("startBundle should have been called", startBundleCalls, greaterThan(0));
-      assertThat(
-          "there should be one bundle that has been started but not finished",
-          startBundleCalls,
-          equalTo(finishBundleCalls + 1));
-      assertThat("teardown should not have been called", teardownCalled, is(false));
-      finishBundleCalls++;
-    }
-
-    @Override
-    public void teardown() {
-      assertThat(setupCalled, is(true));
-      assertThat(startBundleCalls, anyOf(equalTo(finishBundleCalls)));
-      assertThat(teardownCalled, is(false));
-      teardownCalled = true;
-    }
-  }
-
-  @Test
-  @Category(RunnableOnService.class)
-  public void testFnWithContextCallSequence() {
-    TestPipeline p = TestPipeline.create();
-    PCollectionList.of(p.apply("Impolite", Create.of(1, 2, 4)))
-        .and(p.apply("Polite", Create.of(3, 5, 6, 7)))
-        .apply(Flatten.<Integer>pCollections())
-        .apply(ParDo.of(new CallSequenceEnforcingFn<Integer>()));
-
-    p.run();
-  }
-
-  @Test
-  @Category(RunnableOnService.class)
-  public void testFnWithContextCallSequenceMulti() {
-    TestPipeline p = TestPipeline.create();
-    PCollectionList.of(p.apply("Impolite", Create.of(1, 2, 4)))
-        .and(p.apply("Polite", Create.of(3, 5, 6, 7)))
-        .apply(Flatten.<Integer>pCollections())
-        .apply(ParDo.of(new CallSequenceEnforcingFn<Integer>())
-            .withOutputTags(new TupleTag<Integer>() {
-            }, TupleTagList.empty()));
-
-    p.run();
-  }
-
-  private static class CallSequenceEnforcingFn<T> extends DoFn<T, T> {
-    private boolean setupCalled = false;
-    private int startBundleCalls = 0;
-    private int finishBundleCalls = 0;
-    private boolean teardownCalled = false;
-
-    @Setup
-    public void before() {
-      assertThat("setup should not be called twice", setupCalled, is(false));
-      assertThat("setup should be called before startBundle", startBundleCalls, equalTo(0));
-      assertThat("setup should be called before finishBundle", finishBundleCalls, equalTo(0));
-      assertThat("setup should be called before teardown", teardownCalled, is(false));
-      setupCalled = true;
-    }
-
-    @StartBundle
-    public void begin(Context c) {
-      assertThat("setup should have been called", setupCalled, is(true));
-      assertThat("Even number of startBundle and finishBundle calls in startBundle",
-          startBundleCalls,
-          equalTo(finishBundleCalls));
-      assertThat("teardown should not have been called", teardownCalled, is(false));
-      startBundleCalls++;
-    }
-
-    @ProcessElement
-    public void process(ProcessContext c) throws Exception {
-      assertThat("startBundle should have been called", startBundleCalls, greaterThan(0));
-      assertThat("there should be one startBundle call with no call to finishBundle",
-          startBundleCalls,
-          equalTo(finishBundleCalls + 1));
-      assertThat("teardown should not have been called", teardownCalled, is(false));
-    }
-
-    @FinishBundle
-    public void end(Context c) {
-      assertThat("startBundle should have been called", startBundleCalls, greaterThan(0));
-      assertThat("there should be one bundle that has been started but not finished",
-          startBundleCalls,
-          equalTo(finishBundleCalls + 1));
-      assertThat("teardown should not have been called", teardownCalled, is(false));
-      finishBundleCalls++;
-    }
-
-    @Teardown
-    public void after() {
-      assertThat(setupCalled, is(true));
-      assertThat(startBundleCalls, anyOf(equalTo(finishBundleCalls)));
-      assertThat(teardownCalled, is(false));
-      teardownCalled = true;
-    }
-  }
-
-  @Test
-  @Category(NeedsRunner.class)
-  public void testTeardownCalledAfterExceptionInSetup() {
-    TestPipeline p = TestPipeline.create();
-    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.SETUP);
-    p
-        .apply(Create.of(1, 2, 3))
-        .apply(ParDo.of(fn));
-    try {
-      p.run();
-      fail("Pipeline should have failed with an exception");
-    } catch (Exception e) {
-      assertThat(
-          "Function should have been torn down after exception",
-          ExceptionThrowingOldFn.teardownCalled.get(),
-          is(true));
-    }
-  }
-
-  @Test
-  @Category(NeedsRunner.class)
-  public void testTeardownCalledAfterExceptionInStartBundle() {
-    TestPipeline p = TestPipeline.create();
-    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.START_BUNDLE);
-    p
-        .apply(Create.of(1, 2, 3))
-        .apply(ParDo.of(fn));
-    try {
-      p.run();
-      fail("Pipeline should have failed with an exception");
-    } catch (Exception e) {
-      assertThat(
-          "Function should have been torn down after exception",
-          ExceptionThrowingOldFn.teardownCalled.get(),
-          is(true));
-    }
-  }
-
-  @Test
-  @Category(NeedsRunner.class)
-  public void testTeardownCalledAfterExceptionInProcessElement() {
-    TestPipeline p = TestPipeline.create();
-    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.PROCESS_ELEMENT);
-    p
-        .apply(Create.of(1, 2, 3))
-        .apply(ParDo.of(fn));
-    try {
-      p.run();
-      fail("Pipeline should have failed with an exception");
-    } catch (Exception e) {
-      assertThat(
-          "Function should have been torn down after exception",
-          ExceptionThrowingOldFn.teardownCalled.get(),
-          is(true));
-    }
-  }
-
-  @Test
-  @Category(NeedsRunner.class)
-  public void testTeardownCalledAfterExceptionInFinishBundle() {
-    TestPipeline p = TestPipeline.create();
-    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.FINISH_BUNDLE);
-    p
-        .apply(Create.of(1, 2, 3))
-        .apply(ParDo.of(fn));
-    try {
-      p.run();
-      fail("Pipeline should have failed with an exception");
-    } catch (Exception e) {
-      assertThat(
-          "Function should have been torn down after exception",
-          ExceptionThrowingOldFn.teardownCalled.get(),
-          is(true));
-    }
-  }
-
-  @Test
-  @Category(NeedsRunner.class)
-  public void testWithContextTeardownCalledAfterExceptionInSetup() {
-    TestPipeline p = TestPipeline.create();
-    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.SETUP);
-    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
-    try {
-      p.run();
-      fail("Pipeline should have failed with an exception");
-    } catch (Exception e) {
-      assertThat("Function should have been torn down after exception",
-          ExceptionThrowingOldFn.teardownCalled.get(),
-          is(true));
-    }
-  }
-
-  @Test
-  @Category(NeedsRunner.class)
-  public void testWithContextTeardownCalledAfterExceptionInStartBundle() {
-    TestPipeline p = TestPipeline.create();
-    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.START_BUNDLE);
-    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
-    try {
-      p.run();
-      fail("Pipeline should have failed with an exception");
-    } catch (Exception e) {
-      assertThat("Function should have been torn down after exception",
-          ExceptionThrowingOldFn.teardownCalled.get(),
-          is(true));
-    }
-  }
-
-  @Test
-  @Category(NeedsRunner.class)
-  public void testWithContextTeardownCalledAfterExceptionInProcessElement() {
-    TestPipeline p = TestPipeline.create();
-    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.PROCESS_ELEMENT);
-    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
-    try {
-      p.run();
-      fail("Pipeline should have failed with an exception");
-    } catch (Exception e) {
-      assertThat("Function should have been torn down after exception",
-          ExceptionThrowingOldFn.teardownCalled.get(),
-          is(true));
-    }
-  }
-
-  @Test
-  @Category(NeedsRunner.class)
-  public void testWithContextTeardownCalledAfterExceptionInFinishBundle() {
-    TestPipeline p = TestPipeline.create();
-    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.FINISH_BUNDLE);
-    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
-    try {
-      p.run();
-      fail("Pipeline should have failed with an exception");
-    } catch (Exception e) {
-      assertThat("Function should have been torn down after exception",
-          ExceptionThrowingOldFn.teardownCalled.get(),
-          is(true));
-    }
-  }
-
-  private static class ExceptionThrowingOldFn extends OldDoFn<Object, Object> {
-    static AtomicBoolean teardownCalled = new AtomicBoolean(false);
-
-    private final MethodForException toThrow;
-    private boolean thrown;
-
-    private ExceptionThrowingOldFn(MethodForException toThrow) {
-      this.toThrow = toThrow;
-    }
-
-    @Override
-    public void setup() throws Exception {
-      throwIfNecessary(MethodForException.SETUP);
-    }
-
-    @Override
-    public void startBundle(Context c) throws Exception {
-      throwIfNecessary(MethodForException.START_BUNDLE);
-    }
-
-    @Override
-    public void processElement(ProcessContext c) throws Exception {
-      throwIfNecessary(MethodForException.PROCESS_ELEMENT);
-    }
-
-    @Override
-    public void finishBundle(Context c) throws Exception {
-      throwIfNecessary(MethodForException.FINISH_BUNDLE);
-    }
-
-    private void throwIfNecessary(MethodForException method) throws Exception {
-      if (toThrow == method && !thrown) {
-        thrown = true;
-        throw new Exception("Hasn't yet thrown");
-      }
-    }
-
-    @Override
-    public void teardown() {
-      if (!thrown) {
-        fail("Excepted to have a processing method throw an exception");
-      }
-      teardownCalled.set(true);
-    }
-  }
-
-
-  private static class ExceptionThrowingFn extends DoFn<Object, Object> {
-    static AtomicBoolean teardownCalled = new AtomicBoolean(false);
-
-    private final MethodForException toThrow;
-    private boolean thrown;
-
-    private ExceptionThrowingFn(MethodForException toThrow) {
-      this.toThrow = toThrow;
-    }
-
-    @Setup
-    public void before() throws Exception {
-      throwIfNecessary(MethodForException.SETUP);
-    }
-
-    @StartBundle
-    public void preBundle(Context c) throws Exception {
-      throwIfNecessary(MethodForException.START_BUNDLE);
-    }
-
-    @ProcessElement
-    public void perElement(ProcessContext c) throws Exception {
-      throwIfNecessary(MethodForException.PROCESS_ELEMENT);
-    }
-
-    @FinishBundle
-    public void postBundle(Context c) throws Exception {
-      throwIfNecessary(MethodForException.FINISH_BUNDLE);
-    }
-
-    private void throwIfNecessary(MethodForException method) throws Exception {
-      if (toThrow == method && !thrown) {
-        thrown = true;
-        throw new Exception("Hasn't yet thrown");
-      }
-    }
-
-    @Teardown
-    public void after() {
-      if (!thrown) {
-        fail("Excepted to have a processing method throw an exception");
-      }
-      teardownCalled.set(true);
-    }
-  }
-
-  private enum MethodForException {
-    SETUP,
-    START_BUNDLE,
-    PROCESS_ELEMENT,
-    FINISH_BUNDLE
-  }
 }