You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by dh...@apache.org on 2016/08/15 21:17:10 UTC

[5/5] incubator-beam git commit: Replace CloningThreadLocal with DoFnLifecycleManager

Replace CloningThreadLocal with DoFnLifecycleManager

DoFnLifecycleManager is a more focused interface that interacts with a
DoFn before it is available for use and after it has completed and the
reference to it is lost. It is required to properly support setup and
teardown, as the per-thread values held by a ThreadLocal cannot all be
cleaned up without additional tracking.

Part of BEAM-452.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/cf0bf3bf
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/cf0bf3bf
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/cf0bf3bf

Branch: refs/heads/master
Commit: cf0bf3bf9fcab2b01d69ff90d9ba3f602a8a5bd4
Parents: 12b1967
Author: Thomas Groh <tg...@google.com>
Authored: Tue Jul 19 11:03:15 2016 -0700
Committer: Dan Halperin <dh...@google.com>
Committed: Mon Aug 15 14:16:54 2016 -0700

----------------------------------------------------------------------
 .../beam/runners/direct/CloningThreadLocal.java |  43 ------
 .../runners/direct/DoFnLifecycleManager.java    |  78 ++++++++++
 ...ecycleManagerRemovingTransformEvaluator.java |  80 +++++++++++
 .../direct/ParDoMultiEvaluatorFactory.java      |  56 ++++----
 .../direct/ParDoSingleEvaluatorFactory.java     |  43 +++---
 ...readLocalInvalidatingTransformEvaluator.java |  63 --------
 .../runners/direct/CloningThreadLocalTest.java  |  92 ------------
 ...leManagerRemovingTransformEvaluatorTest.java | 144 +++++++++++++++++++
 .../direct/DoFnLifecycleManagerTest.java        | 119 +++++++++++++++
 ...LocalInvalidatingTransformEvaluatorTest.java | 135 -----------------
 10 files changed, 475 insertions(+), 378 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CloningThreadLocal.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CloningThreadLocal.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CloningThreadLocal.java
deleted file mode 100644
index b9dc4ca..0000000
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CloningThreadLocal.java
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.direct;
-
-import org.apache.beam.sdk.util.SerializableUtils;
-
-import java.io.Serializable;
-
-/**
- * A {@link ThreadLocal} that obtains the initial value by cloning an original value.
- */
-class CloningThreadLocal<T extends Serializable> extends ThreadLocal<T> {
-  public static <T extends Serializable> CloningThreadLocal<T> of(T original) {
-    return new CloningThreadLocal<>(original);
-  }
-
-  private final T original;
-
-  private CloningThreadLocal(T original) {
-    this.original = original;
-  }
-
-  @Override
-  public T initialValue() {
-    return SerializableUtils.clone(original);
-  }
-}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java
new file mode 100644
index 0000000..2783657
--- /dev/null
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.OldDoFn;
+import org.apache.beam.sdk.util.SerializableUtils;
+
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
+
+/**
+ * Manages {@link DoFn} setup, teardown, and serialization.
+ *
+ * <p>{@link DoFnLifecycleManager} is similar to a {@link ThreadLocal} storing a {@link DoFn}, but
+ * calls the {@link DoFn} {@code Setup} the first time the {@link DoFn} is obtained and {@code
+ * Teardown} whenever the {@link DoFn} is removed, and provides a method for clearing all cached
+ * {@link DoFn DoFns}.
+ */
+class DoFnLifecycleManager {
+  public static DoFnLifecycleManager of(OldDoFn<?, ?> original) {
+    return new DoFnLifecycleManager(original);
+  }
+
+  private final LoadingCache<Thread, OldDoFn<?, ?>> outstanding;
+
+  private DoFnLifecycleManager(OldDoFn<?, ?> original) {
+    this.outstanding = CacheBuilder.newBuilder().build(new DeserializingCacheLoader(original));
+  }
+
+  public OldDoFn<?, ?> get() throws Exception {
+    Thread currentThread = Thread.currentThread();
+    return outstanding.get(currentThread);
+  }
+
+  public void remove() throws Exception {
+    Thread currentThread = Thread.currentThread();
+    outstanding.invalidate(currentThread);
+  }
+
+  /**
+   * Remove all {@link DoFn DoFns} from this {@link DoFnLifecycleManager}.
+   */
+  public void removeAll() throws Exception {
+    outstanding.invalidateAll();
+  }
+
+  private class DeserializingCacheLoader extends CacheLoader<Thread, OldDoFn<?, ?>> {
+    private final byte[] original;
+
+    public DeserializingCacheLoader(OldDoFn<?, ?> original) {
+      this.original = SerializableUtils.serializeToByteArray(original);
+    }
+
+    @Override
+    public OldDoFn<?, ?> load(Thread key) throws Exception {
+      return (OldDoFn<?, ?>) SerializableUtils.deserializeFromByteArray(original,
+          "DoFn Copy in thread " + key.getName());
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java
new file mode 100644
index 0000000..f3d1d4f
--- /dev/null
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import org.apache.beam.sdk.util.WindowedValue;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A {@link TransformEvaluator} which delegates calls to an underlying {@link TransformEvaluator},
+ * clearing the value of a {@link DoFnLifecycleManager} if any call throws an exception.
+ */
+class DoFnLifecycleManagerRemovingTransformEvaluator<InputT> implements TransformEvaluator<InputT> {
+  private static final Logger LOG =
+      LoggerFactory.getLogger(DoFnLifecycleManagerRemovingTransformEvaluator.class);
+  private final TransformEvaluator<InputT> underlying;
+  private final DoFnLifecycleManager lifecycleManager;
+
+  public static <InputT> TransformEvaluator<InputT> wrapping(
+      TransformEvaluator<InputT> underlying, DoFnLifecycleManager threadLocal) {
+    return new DoFnLifecycleManagerRemovingTransformEvaluator<>(underlying, threadLocal);
+  }
+
+  private DoFnLifecycleManagerRemovingTransformEvaluator(
+      TransformEvaluator<InputT> underlying, DoFnLifecycleManager threadLocal) {
+    this.underlying = underlying;
+    this.lifecycleManager = threadLocal;
+  }
+
+  @Override
+  public void processElement(WindowedValue<InputT> element) throws Exception {
+    try {
+      underlying.processElement(element);
+    } catch (Exception e) {
+      try {
+        lifecycleManager.remove();
+      } catch (Exception removalException) {
+        LOG.error(
+            "Exception encountered while cleaning up after processing an element",
+            removalException);
+        e.addSuppressed(removalException);
+      }
+      throw e;
+    }
+  }
+
+  @Override
+  public TransformResult finishBundle() throws Exception {
+    try {
+      return underlying.finishBundle();
+    } catch (Exception e) {
+      try {
+        lifecycleManager.remove();
+      } catch (Exception removalException) {
+        LOG.error(
+            "Exception encountered while cleaning up after finishing a bundle",
+            removalException);
+        e.addSuppressed(removalException);
+      }
+      throw e;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
index 40533c0..f2455e1 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
@@ -31,6 +31,9 @@ import com.google.common.cache.CacheBuilder;
 import com.google.common.cache.CacheLoader;
 import com.google.common.cache.LoadingCache;
 
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
 import java.util.Map;
 
 /**
@@ -38,32 +41,26 @@ import java.util.Map;
  * {@link BoundMulti} primitive {@link PTransform}.
  */
 class ParDoMultiEvaluatorFactory implements TransformEvaluatorFactory {
-  private final LoadingCache<AppliedPTransform<?, ?, BoundMulti<?, ?>>, ThreadLocal<OldDoFn<?, ?>>>
+  private static final Logger LOG = LoggerFactory.getLogger(ParDoMultiEvaluatorFactory.class);
+  private final LoadingCache<AppliedPTransform<?, ?, BoundMulti<?, ?>>, DoFnLifecycleManager>
       fnClones;
 
   public ParDoMultiEvaluatorFactory() {
-    fnClones =
-        CacheBuilder.newBuilder()
-            .build(
-                new CacheLoader<
-                    AppliedPTransform<?, ?, BoundMulti<?, ?>>, ThreadLocal<OldDoFn<?, ?>>>() {
-                  @Override
-                  public ThreadLocal<OldDoFn<?, ?>> load(
-                      AppliedPTransform<?, ?, BoundMulti<?, ?>> key)
-                      throws Exception {
-                    @SuppressWarnings({"unchecked", "rawtypes"})
-                    ThreadLocal threadLocal =
-                        (ThreadLocal) CloningThreadLocal.of(key.getTransform().getFn());
-                    return threadLocal;
-                  }
-                });
+    fnClones = CacheBuilder.newBuilder()
+        .build(new CacheLoader<AppliedPTransform<?, ?, BoundMulti<?, ?>>, DoFnLifecycleManager>() {
+          @Override
+          public DoFnLifecycleManager load(AppliedPTransform<?, ?, BoundMulti<?, ?>> key)
+              throws Exception {
+            return DoFnLifecycleManager.of(key.getTransform().getFn());
+          }
+        });
   }
 
   @Override
   public <T> TransformEvaluator<T> forApplication(
       AppliedPTransform<?, ?, ?> application,
       CommittedBundle<?> inputBundle,
-      EvaluationContext evaluationContext) {
+      EvaluationContext evaluationContext) throws Exception {
     @SuppressWarnings({"unchecked", "rawtypes"})
     TransformEvaluator<T> evaluator =
         createMultiEvaluator((AppliedPTransform) application, inputBundle, evaluationContext);
@@ -71,38 +68,45 @@ class ParDoMultiEvaluatorFactory implements TransformEvaluatorFactory {
   }
 
   @Override
-  public void cleanup() {
-
+  public void cleanup() throws Exception {
+    for (DoFnLifecycleManager lifecycleManager : fnClones.asMap().values()) {
+      lifecycleManager.removeAll();
+    }
   }
 
   private <InT, OuT> TransformEvaluator<InT> createMultiEvaluator(
       AppliedPTransform<PCollection<InT>, PCollectionTuple, BoundMulti<InT, OuT>> application,
       CommittedBundle<InT> inputBundle,
-      EvaluationContext evaluationContext) {
+      EvaluationContext evaluationContext) throws Exception {
     Map<TupleTag<?>, PCollection<?>> outputs = application.getOutput().getAll();
 
-    @SuppressWarnings({"unchecked", "rawtypes"})
-    ThreadLocal<OldDoFn<InT, OuT>> fnLocal =
-        (ThreadLocal) fnClones.getUnchecked((AppliedPTransform) application);
+    DoFnLifecycleManager fnLocal = fnClones.getUnchecked((AppliedPTransform) application);
     String stepName = evaluationContext.getStepName(application);
     DirectStepContext stepContext =
         evaluationContext.getExecutionContext(application, inputBundle.getKey())
             .getOrCreateStepContext(stepName, stepName);
     try {
+      @SuppressWarnings({"unchecked", "rawtypes"})
       TransformEvaluator<InT> parDoEvaluator =
           ParDoEvaluator.create(
               evaluationContext,
               stepContext,
               inputBundle,
               application,
-              fnLocal.get(),
+              (OldDoFn) fnLocal.get(),
               application.getTransform().getSideInputs(),
               application.getTransform().getMainOutputTag(),
               application.getTransform().getSideOutputTags().getAll(),
               outputs);
-      return ThreadLocalInvalidatingTransformEvaluator.wrapping(parDoEvaluator, fnLocal);
+      return DoFnLifecycleManagerRemovingTransformEvaluator.wrapping(parDoEvaluator, fnLocal);
     } catch (Exception e) {
-      fnLocal.remove();
+      try {
+        fnLocal.remove();
+      } catch (Exception removalException) {
+        LOG.error("Exception encountered while cleaning up in ParDo evaluator construction",
+            removalException);
+        e.addSuppressed(removalException);
+      }
       throw e;
     }
   }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
index 201fb46..a0fbd1d 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
@@ -31,6 +31,9 @@ import com.google.common.cache.CacheLoader;
 import com.google.common.cache.LoadingCache;
 import com.google.common.collect.ImmutableMap;
 
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
 import java.util.Collections;
 
 /**
@@ -38,22 +41,18 @@ import java.util.Collections;
  * {@link Bound ParDo.Bound} primitive {@link PTransform}.
  */
 class ParDoSingleEvaluatorFactory implements TransformEvaluatorFactory {
-  private final LoadingCache<AppliedPTransform<?, ?, Bound<?, ?>>, ThreadLocal<OldDoFn<?, ?>>>
-      fnClones;
+  private static final Logger LOG = LoggerFactory.getLogger(ParDoSingleEvaluatorFactory.class);
+  private final LoadingCache<AppliedPTransform<?, ?, Bound<?, ?>>, DoFnLifecycleManager> fnClones;
 
   public ParDoSingleEvaluatorFactory() {
     fnClones =
         CacheBuilder.newBuilder()
             .build(
-                new CacheLoader<
-                    AppliedPTransform<?, ?, Bound<?, ?>>, ThreadLocal<OldDoFn<?, ?>>>() {
+                new CacheLoader<AppliedPTransform<?, ?, Bound<?, ?>>, DoFnLifecycleManager>() {
                   @Override
-                  public ThreadLocal<OldDoFn<?, ?>> load(AppliedPTransform<?, ?, Bound<?, ?>> key)
+                  public DoFnLifecycleManager load(AppliedPTransform<?, ?, Bound<?, ?>> key)
                       throws Exception {
-                    @SuppressWarnings({"unchecked", "rawtypes"})
-                    ThreadLocal threadLocal =
-                        (ThreadLocal) CloningThreadLocal.of(key.getTransform().getFn());
-                    return threadLocal;
+                    return DoFnLifecycleManager.of(key.getTransform().getFn());
                   }
                 });
   }
@@ -62,7 +61,7 @@ class ParDoSingleEvaluatorFactory implements TransformEvaluatorFactory {
   public <T> TransformEvaluator<T> forApplication(
       final AppliedPTransform<?, ?, ?> application,
       CommittedBundle<?> inputBundle,
-      EvaluationContext evaluationContext) {
+      EvaluationContext evaluationContext) throws Exception {
     @SuppressWarnings({"unchecked", "rawtypes"})
     TransformEvaluator<T> evaluator =
         createSingleEvaluator((AppliedPTransform) application, inputBundle, evaluationContext);
@@ -70,39 +69,45 @@ class ParDoSingleEvaluatorFactory implements TransformEvaluatorFactory {
   }
 
   @Override
-  public void cleanup() {
-
+  public void cleanup() throws Exception {
+    for (DoFnLifecycleManager lifecycleManager : fnClones.asMap().values()) {
+      lifecycleManager.removeAll();
+    }
   }
 
   private <InputT, OutputT> TransformEvaluator<InputT> createSingleEvaluator(
       AppliedPTransform<PCollection<InputT>, PCollection<OutputT>, Bound<InputT, OutputT>>
           application,
       CommittedBundle<InputT> inputBundle,
-      EvaluationContext evaluationContext) {
+      EvaluationContext evaluationContext) throws Exception {
     TupleTag<OutputT> mainOutputTag = new TupleTag<>("out");
     String stepName = evaluationContext.getStepName(application);
     DirectStepContext stepContext =
         evaluationContext.getExecutionContext(application, inputBundle.getKey())
             .getOrCreateStepContext(stepName, stepName);
 
-    @SuppressWarnings({"unchecked", "rawtypes"})
-    ThreadLocal<OldDoFn<InputT, OutputT>> fnLocal =
-        (ThreadLocal) fnClones.getUnchecked((AppliedPTransform) application);
+    DoFnLifecycleManager fnLocal = fnClones.getUnchecked((AppliedPTransform) application);
     try {
+      @SuppressWarnings({"unchecked", "rawtypes"})
       ParDoEvaluator<InputT> parDoEvaluator =
           ParDoEvaluator.create(
               evaluationContext,
               stepContext,
               inputBundle,
               application,
-              fnLocal.get(),
+              (OldDoFn) fnLocal.get(),
               application.getTransform().getSideInputs(),
               mainOutputTag,
               Collections.<TupleTag<?>>emptyList(),
               ImmutableMap.<TupleTag<?>, PCollection<?>>of(mainOutputTag, application.getOutput()));
-      return ThreadLocalInvalidatingTransformEvaluator.wrapping(parDoEvaluator, fnLocal);
+      return DoFnLifecycleManagerRemovingTransformEvaluator.wrapping(parDoEvaluator, fnLocal);
     } catch (Exception e) {
-      fnLocal.remove();
+      try {
+        fnLocal.remove();
+      } catch (Exception removalException) {
+        LOG.error("Exception encountered constructing ParDo evaluator", removalException);
+        e.addSuppressed(removalException);
+      }
       throw e;
     }
   }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ThreadLocalInvalidatingTransformEvaluator.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ThreadLocalInvalidatingTransformEvaluator.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ThreadLocalInvalidatingTransformEvaluator.java
deleted file mode 100644
index d8a6bf9..0000000
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ThreadLocalInvalidatingTransformEvaluator.java
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.direct;
-
-import org.apache.beam.sdk.util.WindowedValue;
-
-/**
- * A {@link TransformEvaluator} which delegates calls to an underlying {@link TransformEvaluator},
- * clearing the value of a {@link ThreadLocal} if any call throws an exception.
- */
-class ThreadLocalInvalidatingTransformEvaluator<InputT>
-    implements TransformEvaluator<InputT> {
-  private final TransformEvaluator<InputT> underlying;
-  private final ThreadLocal<?> threadLocal;
-
-  public static <InputT> TransformEvaluator<InputT> wrapping(
-      TransformEvaluator<InputT> underlying,
-      ThreadLocal<?> threadLocal) {
-    return new ThreadLocalInvalidatingTransformEvaluator<>(underlying, threadLocal);
-  }
-
-  private ThreadLocalInvalidatingTransformEvaluator(
-      TransformEvaluator<InputT> underlying, ThreadLocal<?> threadLocal) {
-    this.underlying = underlying;
-    this.threadLocal = threadLocal;
-  }
-
-  @Override
-  public void processElement(WindowedValue<InputT> element) throws Exception {
-    try {
-      underlying.processElement(element);
-    } catch (Exception e) {
-      threadLocal.remove();
-      throw e;
-    }
-  }
-
-  @Override
-  public TransformResult finishBundle() throws Exception {
-    try {
-      return underlying.finishBundle();
-    } catch (Exception e) {
-      threadLocal.remove();
-      throw e;
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CloningThreadLocalTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CloningThreadLocalTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CloningThreadLocalTest.java
deleted file mode 100644
index 298db46..0000000
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CloningThreadLocalTest.java
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.direct;
-
-import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.nullValue;
-import static org.hamcrest.core.IsNot.not;
-import static org.hamcrest.core.IsSame.theInstance;
-import static org.junit.Assert.assertThat;
-
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-
-import java.io.Serializable;
-import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
-
-/**
- * Tests for {@link CloningThreadLocalTest}.
- */
-@RunWith(JUnit4.class)
-public class CloningThreadLocalTest {
-  @Test
-  public void returnsCopiesOfOriginal() throws Exception {
-    Record original = new Record();
-    ThreadLocal<Record> loaded = CloningThreadLocal.of(original);
-    assertThat(loaded.get(), not(nullValue()));
-    assertThat(loaded.get(), equalTo(original));
-    assertThat(loaded.get(), not(theInstance(original)));
-  }
-
-  @Test
-  public void returnsDifferentCopiesInDifferentThreads() throws Exception {
-    Record original = new Record();
-    final ThreadLocal<Record> loaded = CloningThreadLocal.of(original);
-    assertThat(loaded.get(), not(nullValue()));
-    assertThat(loaded.get(), equalTo(original));
-    assertThat(loaded.get(), not(theInstance(original)));
-
-    Callable<Record> otherThread =
-        new Callable<Record>() {
-          @Override
-          public Record call() throws Exception {
-            return loaded.get();
-          }
-        };
-    Record sameThread = loaded.get();
-    Record firstOtherThread = Executors.newSingleThreadExecutor().submit(otherThread).get();
-    Record secondOtherThread = Executors.newSingleThreadExecutor().submit(otherThread).get();
-
-    assertThat(sameThread, equalTo(firstOtherThread));
-    assertThat(sameThread, equalTo(secondOtherThread));
-    assertThat(sameThread, not(theInstance(firstOtherThread)));
-    assertThat(sameThread, not(theInstance(secondOtherThread)));
-    assertThat(firstOtherThread, not(theInstance(secondOtherThread)));
-  }
-
-  private static class Record implements Serializable {
-    private final double rand = Math.random();
-
-    @Override
-    public boolean equals(Object other) {
-      if (!(other instanceof Record)) {
-        return false;
-      }
-      Record that = (Record) other;
-      return this.rand == that.rand;
-    }
-
-    @Override
-    public int hashCode() {
-      return 1;
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluatorTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluatorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluatorTest.java
new file mode 100644
index 0000000..67f4ff4
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluatorTest.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
+import static org.hamcrest.core.Is.is;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.fail;
+
+import org.apache.beam.sdk.transforms.OldDoFn;
+import org.apache.beam.sdk.util.WindowedValue;
+
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Tests for {@link DoFnLifecycleManagerRemovingTransformEvaluator}.
+ */
+@RunWith(JUnit4.class)
+public class DoFnLifecycleManagerRemovingTransformEvaluatorTest {
+  private DoFnLifecycleManager lifecycleManager;
+
+  @Before
+  public void setup() {
+    lifecycleManager = DoFnLifecycleManager.of(new TestFn());
+  }
+
+  /** Elements and the finishBundle call must both reach the wrapped evaluator. */
+  @Test
+  public void delegatesToUnderlying() throws Exception {
+    RecordingTransformEvaluator underlying = new RecordingTransformEvaluator();
+    TransformEvaluator<Object> evaluator =
+        DoFnLifecycleManagerRemovingTransformEvaluator.wrapping(underlying, lifecycleManager);
+    WindowedValue<Object> first = WindowedValue.valueInGlobalWindow(new Object());
+    WindowedValue<Object> second = WindowedValue.valueInGlobalWindow(new Object());
+    evaluator.processElement(first);
+    assertThat(underlying.objects, containsInAnyOrder(first));
+    evaluator.processElement(second);
+    assertThat(underlying.finishBundleCalled, is(false));
+    evaluator.finishBundle();
+
+    assertThat(underlying.finishBundleCalled, is(true));
+    assertThat(underlying.objects, containsInAnyOrder(second, first));
+  }
+
+  /** A processElement failure must make the manager stop handing out the old fn. */
+  @Test
+  public void removesOnExceptionInProcessElement() throws Exception {
+    ThrowingTransformEvaluator underlying = new ThrowingTransformEvaluator();
+    OldDoFn<?, ?> original = lifecycleManager.get();
+    assertThat(original, not(nullValue()));
+    TransformEvaluator<Object> evaluator =
+        DoFnLifecycleManagerRemovingTransformEvaluator.wrapping(underlying, lifecycleManager);
+
+    try {
+      evaluator.processElement(WindowedValue.valueInGlobalWindow(new Object()));
+    } catch (Exception e) {
+      assertThat(lifecycleManager.get(), not(Matchers.<OldDoFn<?, ?>>theInstance(original)));
+      return;
+    }
+    fail("Expected ThrowingTransformEvaluator to throw on method call");
+  }
+
+  /** A finishBundle failure must make the manager stop handing out the old fn. */
+  @Test
+  public void removesOnExceptionInFinishBundle() throws Exception {
+    ThrowingTransformEvaluator underlying = new ThrowingTransformEvaluator();
+    OldDoFn<?, ?> original = lifecycleManager.get();
+    // Obtain the fn up front so there is an instance to compare against after the failure.
+    assertThat(original, not(nullValue()));
+    TransformEvaluator<Object> evaluator =
+        DoFnLifecycleManagerRemovingTransformEvaluator.wrapping(underlying, lifecycleManager);
+
+    try {
+      evaluator.finishBundle();
+    } catch (Exception e) {
+      assertThat(lifecycleManager.get(), not(Matchers.<OldDoFn<?, ?>>theInstance(original)));
+      return;
+    }
+    fail("Expected ThrowingTransformEvaluator to throw on method call");
+  }
+
+  /** Records every call it receives so that tests can verify delegation. */
+  private static class RecordingTransformEvaluator implements TransformEvaluator<Object> {
+    // Starts false so that only a real finishBundle() call can flip it.
+    private boolean finishBundleCalled = false;
+    // Elements handed to processElement, in arrival order.
+    private final List<WindowedValue<Object>> objects = new ArrayList<>();
+
+    @Override
+    public void processElement(WindowedValue<Object> element) throws Exception {
+      objects.add(element);
+    }
+
+    @Override
+    public TransformResult finishBundle() throws Exception {
+      finishBundleCalled = true;
+      return null;
+    }
+  }
+
+  /** Throws from every method, to drive the wrapper's failure handling. */
+  private static class ThrowingTransformEvaluator implements TransformEvaluator<Object> {
+    @Override
+    public void processElement(WindowedValue<Object> element) throws Exception {
+      throw new Exception();
+    }
+
+    @Override
+    public TransformResult finishBundle() throws Exception {
+      throw new Exception();
+    }
+  }
+
+  /** No-op fn; these tests only compare instance identity. */
+  private static class TestFn extends OldDoFn<Object, Object> {
+    @Override
+    public void processElement(ProcessContext c) throws Exception {}
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java
new file mode 100644
index 0000000..f316e19
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.theInstance;
+import static org.junit.Assert.assertThat;
+
+import org.apache.beam.sdk.transforms.OldDoFn;
+
+import org.hamcrest.Matchers;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Tests for {@link DoFnLifecycleManager}.
+ */
+public class DoFnLifecycleManagerTest {
+  private final TestFn fn = new TestFn();
+  private final DoFnLifecycleManager mgr = DoFnLifecycleManager.of(fn);
+
+  /** get() must hand out an instance other than the one the manager was created with. */
+  @Test
+  public void setupOnGet() throws Exception {
+    TestFn obtained = (TestFn) mgr.get();
+    assertThat(obtained, not(theInstance(fn)));
+  }
+
+  /** Repeated get() calls on one thread must return the same instance. */
+  @Test
+  public void getMultipleCallsSingleSetupCall() throws Exception {
+    TestFn obtained = (TestFn) mgr.get();
+    TestFn secondObtained = (TestFn) mgr.get();
+    assertThat(obtained, theInstance(secondObtained));
+  }
+
+  /** Concurrent get() calls from different threads must each yield a distinct instance. */
+  @Test
+  public void getMultipleThreadsDifferentInstances() throws Exception {
+    CountDownLatch startSignal = new CountDownLatch(1);
+    ExecutorService executor = Executors.newCachedThreadPool();
+    List<Future<TestFn>> futures = new ArrayList<>();
+    for (int i = 0; i < 10; i++) {
+      futures.add(executor.submit(new GetFnCallable(mgr, startSignal)));
+    }
+    executor.shutdown(); // submitted tasks still run; workers exit afterwards instead of leaking
+    startSignal.countDown();
+    List<TestFn> fns = new ArrayList<>();
+    for (Future<TestFn> future : futures) {
+      fns.add(future.get(1L, TimeUnit.SECONDS));
+    }
+
+    for (TestFn candidate : fns) {
+      int sameInstances = 0;
+      for (TestFn other : fns) {
+        if (other == candidate) {
+          sameInstances++;
+        }
+      }
+      assertThat(sameInstances, equalTo(1));
+    }
+  }
+
+  /** After remove(), get() must produce an instance other than the removed one. */
+  @Test
+  public void teardownOnRemove() throws Exception {
+    TestFn obtained = (TestFn) mgr.get();
+    mgr.remove();
+    assertThat(obtained, not(theInstance(fn)));
+    assertThat(mgr.get(), not(Matchers.<OldDoFn<?, ?>>theInstance(obtained)));
+  }
+
+  /** Waits for the start signal, then fetches the calling thread's fn from the manager. */
+  private static class GetFnCallable implements Callable<TestFn> {
+    private final DoFnLifecycleManager mgr;
+    private final CountDownLatch startSignal;
+    private GetFnCallable(DoFnLifecycleManager mgr, CountDownLatch startSignal) {
+      this.mgr = mgr;
+      this.startSignal = startSignal;
+    }
+
+    @Override
+    public TestFn call() throws Exception {
+      startSignal.await();
+      return (TestFn) mgr.get();
+    }
+  }
+
+  /** No-op fn; these tests only compare instance identity. */
+  private static class TestFn extends OldDoFn<Object, Object> {
+    @Override
+    public void processElement(ProcessContext c) throws Exception {}
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf0bf3bf/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ThreadLocalInvalidatingTransformEvaluatorTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ThreadLocalInvalidatingTransformEvaluatorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ThreadLocalInvalidatingTransformEvaluatorTest.java
deleted file mode 100644
index 6e477d3..0000000
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ThreadLocalInvalidatingTransformEvaluatorTest.java
+++ /dev/null
@@ -1,135 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.direct;
-
-import static org.hamcrest.Matchers.containsInAnyOrder;
-import static org.hamcrest.Matchers.not;
-import static org.hamcrest.Matchers.nullValue;
-import static org.hamcrest.core.Is.is;
-import static org.junit.Assert.assertThat;
-import static org.junit.Assert.fail;
-
-import org.apache.beam.sdk.util.WindowedValue;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * Tests for {@link ThreadLocalInvalidatingTransformEvaluator}.
- */
-@RunWith(JUnit4.class)
-public class ThreadLocalInvalidatingTransformEvaluatorTest {
-  private ThreadLocal<Object> threadLocal;
-
-  @Before
-  public void setup() {
-    threadLocal = new ThreadLocal<>();
-    threadLocal.set(new Object());
-  }
-
-  @Test
-  public void delegatesToUnderlying() throws Exception {
-    RecordingTransformEvaluator underlying = new RecordingTransformEvaluator();
-    Object original = threadLocal.get();
-    TransformEvaluator<Object> evaluator =
-        ThreadLocalInvalidatingTransformEvaluator.wrapping(underlying, threadLocal);
-    WindowedValue<Object> first = WindowedValue.valueInGlobalWindow(new Object());
-    WindowedValue<Object> second = WindowedValue.valueInGlobalWindow(new Object());
-    evaluator.processElement(first);
-    assertThat(underlying.objects, containsInAnyOrder(first));
-    evaluator.processElement(second);
-    evaluator.finishBundle();
-
-    assertThat(underlying.finishBundleCalled, is(true));
-    assertThat(underlying.objects, containsInAnyOrder(second, first));
-  }
-
-  @Test
-  public void removesOnExceptionInProcessElement() {
-    ThrowingTransformEvaluator underlying = new ThrowingTransformEvaluator();
-    Object original = threadLocal.get();
-    assertThat(original, not(nullValue()));
-    TransformEvaluator<Object> evaluator =
-        ThreadLocalInvalidatingTransformEvaluator.wrapping(underlying, threadLocal);
-
-    try {
-      evaluator.processElement(WindowedValue.valueInGlobalWindow(new Object()));
-    } catch (Exception e) {
-      assertThat(threadLocal.get(), nullValue());
-      return;
-    }
-    fail("Expected ThrowingTransformEvaluator to throw on method call");
-  }
-
-  @Test
-  public void removesOnExceptionInFinishBundle() {
-    ThrowingTransformEvaluator underlying = new ThrowingTransformEvaluator();
-    Object original = threadLocal.get();
-    // the ThreadLocal is set when the evaluator starts
-    assertThat(original, not(nullValue()));
-    TransformEvaluator<Object> evaluator =
-        ThreadLocalInvalidatingTransformEvaluator.wrapping(underlying, threadLocal);
-
-    try {
-      evaluator.finishBundle();
-    } catch (Exception e) {
-      assertThat(threadLocal.get(), nullValue());
-      return;
-    }
-    fail("Expected ThrowingTransformEvaluator to throw on method call");
-  }
-
-  private class RecordingTransformEvaluator implements TransformEvaluator<Object> {
-    private boolean finishBundleCalled;
-    private List<WindowedValue<Object>> objects;
-
-    public RecordingTransformEvaluator() {
-      this.finishBundleCalled = true;
-      this.objects = new ArrayList<>();
-    }
-
-    @Override
-    public void processElement(WindowedValue<Object> element) throws Exception {
-      objects.add(element);
-    }
-
-    @Override
-    public TransformResult finishBundle() throws Exception {
-      finishBundleCalled = true;
-      return null;
-    }
-  }
-
-  private class ThrowingTransformEvaluator implements TransformEvaluator<Object> {
-    @Override
-    public void processElement(WindowedValue<Object> element) throws Exception {
-      throw new Exception();
-    }
-
-    @Override
-    public TransformResult finishBundle() throws Exception {
-      throw new Exception();
-    }
-  }
-}