You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@beam.apache.org by ke...@apache.org on 2016/11/28 21:17:23 UTC

[4/5] incubator-beam git commit: Add support for Stateful ParDo in the Direct runner

Add support for Stateful ParDo in the Direct runner

This adds overrides and new evaluators to ensure that
state is accessed in a single-threaded manner per key
and is cleaned up when a window expires.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/ec2c0e06
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/ec2c0e06
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/ec2c0e06

Branch: refs/heads/master
Commit: ec2c0e0698c1380b309a609eb642aba445c77e27
Parents: 7e158e4
Author: Kenneth Knowles <kl...@google.com>
Authored: Wed Nov 9 21:59:15 2016 -0800
Committer: Kenneth Knowles <kl...@google.com>
Committed: Mon Nov 28 11:48:32 2016 -0800

----------------------------------------------------------------------
 .../beam/runners/direct/EvaluationContext.java  |  15 +
 .../beam/runners/direct/ParDoEvaluator.java     |  11 +-
 .../runners/direct/ParDoEvaluatorFactory.java   |  53 +++-
 .../direct/ParDoMultiOverrideFactory.java       |  76 ++++-
 .../ParDoSingleViaMultiOverrideFactory.java     |   6 +-
 .../direct/StatefulParDoEvaluatorFactory.java   | 256 ++++++++++++++++
 .../direct/TransformEvaluatorRegistry.java      |   2 +
 .../direct/WatermarkCallbackExecutor.java       |  34 +++
 .../StatefulParDoEvaluatorFactoryTest.java      | 300 +++++++++++++++++++
 .../org/apache/beam/sdk/transforms/DoFn.java    |   4 +-
 .../org/apache/beam/sdk/transforms/OldDoFn.java |   8 +-
 11 files changed, 741 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java
index c1225f6..201aaed 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java
@@ -296,6 +296,21 @@ class EvaluationContext {
     fireAvailableCallbacks(lookupProducing(value));
   }
 
+  /**
+   * Schedule {@code runnable} to be executed after {@code window}, under
+   * {@code windowingStrategy}, has expired for the step {@code producing}. For example,
+   * upstream state associated with the window may then be cleared.
+   */
+  public void scheduleAfterWindowExpiration(
+      AppliedPTransform<?, ?, ?> producing,
+      BoundedWindow window,
+      WindowingStrategy<?, ?> windowingStrategy,
+      Runnable runnable) {
+    callbackExecutor.callOnWindowExpiration(producing, window, windowingStrategy, runnable);
+    // NOTE(review): presumably fires the callback right away if the window is already expired.
+    fireAvailableCallbacks(producing);
+  }
+
   private AppliedPTransform<?, ?, ?> getProducing(PValue value) {
     if (value.getProducingTransformInternal() != null) {
       return value.getProducingTransformInternal();

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java
index 3285c7e..750e5f1 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java
@@ -42,6 +42,7 @@ import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
 
 class ParDoEvaluator<InputT, OutputT> implements TransformEvaluator<InputT> {
+
   public static <InputT, OutputT> ParDoEvaluator<InputT, OutputT> create(
       EvaluationContext evaluationContext,
       DirectStepContext stepContext,
@@ -84,11 +85,17 @@ class ParDoEvaluator<InputT, OutputT> implements TransformEvaluator<InputT> {
     }
 
     return new ParDoEvaluator<>(
-        runner, application, aggregatorChanges, outputBundles.values(), stepContext);
+        evaluationContext,
+        runner,
+        application,
+        aggregatorChanges,
+        outputBundles.values(),
+        stepContext);
   }
 
   ////////////////////////////////////////////////////////////////////////////////////////////////
 
+  private final EvaluationContext evaluationContext;
   private final PushbackSideInputDoFnRunner<InputT, ?> fnRunner;
   private final AppliedPTransform<?, ?, ?> transform;
   private final AggregatorContainer.Mutator aggregatorChanges;
@@ -98,11 +105,13 @@ class ParDoEvaluator<InputT, OutputT> implements TransformEvaluator<InputT> {
   private final ImmutableList.Builder<WindowedValue<InputT>> unprocessedElements;
 
   private ParDoEvaluator(
+      EvaluationContext evaluationContext,
       PushbackSideInputDoFnRunner<InputT, ?> fnRunner,
       AppliedPTransform<?, ?, ?> transform,
       AggregatorContainer.Mutator aggregatorChanges,
       Collection<UncommittedBundle<?>> outputBundles,
       DirectStepContext stepContext) {
+    this.evaluationContext = evaluationContext;
     this.fnRunner = fnRunner;
     this.transform = transform;
     this.outputBundles = outputBundles;

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java
index b776da1..02e034a 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java
@@ -20,14 +20,16 @@ package org.apache.beam.runners.direct;
 import com.google.common.cache.CacheBuilder;
 import com.google.common.cache.CacheLoader;
 import com.google.common.cache.LoadingCache;
+import java.util.List;
 import org.apache.beam.runners.direct.DirectExecutionContext.DirectStepContext;
 import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
 import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.transforms.ParDo.BoundMulti;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -54,10 +56,26 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator
   @Override
   public <T> TransformEvaluator<T> forApplication(
       AppliedPTransform<?, ?, ?> application, CommittedBundle<?> inputBundle) throws Exception {
+    @SuppressWarnings("unchecked") // This factory is registered for ParDo.BoundMulti only.
+    AppliedPTransform<PCollection<InputT>, PCollectionTuple, ParDo.BoundMulti<InputT, OutputT>>
+        parDoApplication =
+            (AppliedPTransform<
+                    PCollection<InputT>, PCollectionTuple, ParDo.BoundMulti<InputT, OutputT>>)
+                application;
+    // Unpack the ParDo so createEvaluator can also serve transforms that merely wrap a ParDo.
+    ParDo.BoundMulti<InputT, OutputT> transform = parDoApplication.getTransform();
+    final DoFn<InputT, OutputT> doFn = transform.getNewFn();
+
     @SuppressWarnings({"unchecked", "rawtypes"})
     TransformEvaluator<T> evaluator =
         (TransformEvaluator<T>)
-            createEvaluator((AppliedPTransform) application, (CommittedBundle) inputBundle);
+            createEvaluator(
+                (AppliedPTransform) application,
+                inputBundle.getKey(),
+                doFn,
+                transform.getSideInputs(),
+                transform.getMainOutputTag(),
+                transform.getSideOutputTags().getAll());
     return evaluator;
   }
 
@@ -66,21 +84,32 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator
     DoFnLifecycleManagers.removeAllFromManagers(fnClones.asMap().values());
   }
 
+  /**
+   * Creates an evaluator for an arbitrary {@link AppliedPTransform} node, with the pieces of the
+   * {@link ParDo} unpacked.
+   *
+   * <p>This can thus be invoked regardless of whether the types in the {@link AppliedPTransform}
+   * correspond with the type in the unpacked {@link DoFn}, side inputs, and output tags.
+   */
   @SuppressWarnings({"unchecked", "rawtypes"})
-  private TransformEvaluator<InputT> createEvaluator(
-      AppliedPTransform<PCollection<InputT>, PCollectionTuple, BoundMulti<InputT, OutputT>>
-          application,
-      CommittedBundle<InputT> inputBundle)
+  TransformEvaluator<InputT> createEvaluator(
+        AppliedPTransform<PCollection<?>, PCollectionTuple, ?>
+        application,
+        StructuralKey<?> inputBundleKey,
+        DoFn<InputT, OutputT> doFn,
+        List<PCollectionView<?>> sideInputs,
+        TupleTag<OutputT> mainOutputTag,
+        List<TupleTag<?>> sideOutputTags)
       throws Exception {
     String stepName = evaluationContext.getStepName(application);
     DirectStepContext stepContext =
         evaluationContext
-            .getExecutionContext(application, inputBundle.getKey())
+            .getExecutionContext(application, inputBundleKey)
             .getOrCreateStepContext(stepName, stepName);
 
-    DoFnLifecycleManager fnManager = fnClones.getUnchecked(application.getTransform().getNewFn());
+    DoFnLifecycleManager fnManager = fnClones.getUnchecked(doFn);
+
     try {
-      ParDo.BoundMulti<InputT, OutputT> transform = application.getTransform();
       return DoFnLifecycleManagerRemovingTransformEvaluator.wrapping(
           ParDoEvaluator.<InputT, OutputT>create(
               evaluationContext,
@@ -88,9 +117,9 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator
               application,
               application.getInput().getWindowingStrategy(),
               fnManager.get(),
-              transform.getSideInputs(),
-              transform.getMainOutputTag(),
-              transform.getSideOutputTags().getAll(),
+              sideInputs,
+              mainOutputTag,
+              sideOutputTags,
               application.getOutput().getAll()),
           fnManager);
     } catch (Exception e) {

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
index 6cc3e6e..8db5159 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
@@ -18,13 +18,19 @@
 package org.apache.beam.runners.direct;
 
 import org.apache.beam.runners.core.SplittableParDo;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.TupleTagList;
+import org.apache.beam.sdk.values.TypedPValue;
 
 /**
  * A {@link PTransformOverrideFactory} that provides overrides for applications of a {@link ParDo}
@@ -42,10 +48,74 @@ class ParDoMultiOverrideFactory<InputT, OutputT>
 
     DoFn<InputT, OutputT> fn = transform.getNewFn();
     DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass());
-    if (!signature.processElement().isSplittable()) {
-      return transform;
-    } else {
+    if (signature.processElement().isSplittable()) {
       return new SplittableParDo(fn);
+    } else if (signature.stateDeclarations().size() > 0
+        || signature.timerDeclarations().size() > 0) {
+
+      // Based on the fact that the signature is stateful, DoFnSignatures ensures
+      // that it is also keyed
+      ParDo.BoundMulti<KV<?, ?>, OutputT> keyedTransform =
+          (ParDo.BoundMulti<KV<?, ?>, OutputT>) transform;
+
+      return new GbkThenStatefulParDo(keyedTransform);
+    } else {
+      return transform;
+    }
+  }
+
+  static class GbkThenStatefulParDo<K, InputT, OutputT>
+      extends PTransform<PCollection<KV<K, InputT>>, PCollectionTuple> {
+    private final ParDo.BoundMulti<KV<K, InputT>, OutputT> underlyingParDo;
+    // This transform replaces a stateful ParDo with a GroupByKey followed by a StatefulParDo.
+    public GbkThenStatefulParDo(ParDo.BoundMulti<KV<K, InputT>, OutputT> underlyingParDo) {
+      this.underlyingParDo = underlyingParDo;
+    }
+
+    @Override
+    public PCollectionTuple apply(PCollection<KV<K, InputT>> input) {
+      // Group by key first so that state is accessed in a single-threaded manner per key.
+      PCollectionTuple outputs = input
+          .apply("Group by key", GroupByKey.<K, InputT>create())
+          .apply("Stateful ParDo", new StatefulParDo<>(underlyingParDo, input));
+
+      return outputs;
+    }
+  }
+
+  static class StatefulParDo<K, InputT, OutputT>
+      extends PTransform<PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple> {
+    private final transient ParDo.BoundMulti<KV<K, InputT>, OutputT> underlyingParDo;
+    private final transient PCollection<KV<K, InputT>> originalInput;
+
+    public StatefulParDo(
+        ParDo.BoundMulti<KV<K, InputT>, OutputT> underlyingParDo,
+        PCollection<KV<K, InputT>> originalInput) {
+      this.underlyingParDo = underlyingParDo;
+      this.originalInput = originalInput;
+    }
+
+    public ParDo.BoundMulti<KV<K, InputT>, OutputT> getUnderlyingParDo() {
+      return underlyingParDo;
+    }
+    // Output coders are those of the underlying ParDo applied to the original, ungrouped input.
+    @Override
+    public <T> Coder<T> getDefaultOutputCoder(
+        PCollection<? extends KV<K, Iterable<InputT>>> input, TypedPValue<T> output)
+        throws CannotProvideCoderException {
+      return underlyingParDo.getDefaultOutputCoder(originalInput, output);
+    }
+
+    @Override
+    public PCollectionTuple apply(PCollection<? extends KV<K, Iterable<InputT>>> input) {
+      PCollectionTuple outputs = PCollectionTuple.ofPrimitiveOutputsInternal(
+          input.getPipeline(),
+          TupleTagList.of(underlyingParDo.getMainOutputTag())
+              .and(underlyingParDo.getSideOutputTags().getAll()),
+          input.getWindowingStrategy(),
+          input.isBounded());
+
+      return outputs;
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java
index ee3dfc5..f220a46 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java
@@ -54,13 +54,15 @@ class ParDoSingleViaMultiOverrideFactory<InputT, OutputT>
       // Output tags for ParDo need only be unique up to applied transform
       TupleTag<OutputT> mainOutputTag = new TupleTag<OutputT>(MAIN_OUTPUT_TAG);
 
-      PCollectionTuple output =
+      PCollectionTuple outputs =
           input.apply(
               ParDo.of(underlyingParDo.getNewFn())
                   .withSideInputs(underlyingParDo.getSideInputs())
                   .withOutputTags(mainOutputTag, TupleTagList.empty()));
+      PCollection<OutputT> output = outputs.get(mainOutputTag);
 
-      return output.get(mainOutputTag);
+      output.setTypeDescriptorInternal(underlyingParDo.getNewFn().getOutputTypeDescriptor());
+      return output;
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
new file mode 100644
index 0000000..1f3286c
--- /dev/null
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
@@ -0,0 +1,256 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.direct;
+
+import com.google.auto.value.AutoValue;
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
+import com.google.common.collect.Lists;
+import java.util.Collections;
+import org.apache.beam.runners.direct.DirectExecutionContext.DirectStepContext;
+import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
+import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowingStrategy;
+import org.apache.beam.sdk.util.state.StateNamespace;
+import org.apache.beam.sdk.util.state.StateNamespaces;
+import org.apache.beam.sdk.util.state.StateSpec;
+import org.apache.beam.sdk.util.state.StateTag;
+import org.apache.beam.sdk.util.state.StateTags;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+
+/** A {@link TransformEvaluatorFactory} for stateful {@link ParDo}, i.e. {@link StatefulParDo}. */
+final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements TransformEvaluatorFactory {
+  // Remembers the scheduled state cleanup per (transform, key, window) to avoid rescheduling.
+  private final LoadingCache<AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT>, Runnable>
+      cleanupRegistry;
+  // Builds the evaluators that run the underlying DoFn on individual (ungrouped) KV elements.
+  private final ParDoEvaluatorFactory<KV<K, InputT>, OutputT> delegateFactory;
+
+  StatefulParDoEvaluatorFactory(EvaluationContext evaluationContext) {
+    this.delegateFactory = new ParDoEvaluatorFactory<>(evaluationContext);
+    this.cleanupRegistry =
+        CacheBuilder.newBuilder()
+            .weakValues()
+            .build(new CleanupSchedulingLoader(evaluationContext));
+  }
+
+  @Override
+  public <T> TransformEvaluator<T> forApplication(
+      AppliedPTransform<?, ?, ?> application, CommittedBundle<?> inputBundle) throws Exception {
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    TransformEvaluator<T> evaluator =
+        (TransformEvaluator<T>)
+            createEvaluator((AppliedPTransform) application, (CommittedBundle) inputBundle);
+    return evaluator;
+  }
+
+  @Override
+  public void cleanup() throws Exception {
+    delegateFactory.cleanup();
+  }
+
+  @SuppressWarnings({"unchecked", "rawtypes"})
+  private TransformEvaluator<KV<K, Iterable<InputT>>> createEvaluator(
+      AppliedPTransform<
+              PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple,
+              StatefulParDo<K, InputT, OutputT>>
+          application,
+      CommittedBundle<KV<K, Iterable<InputT>>> inputBundle)
+      throws Exception {
+
+    final DoFn<KV<K, InputT>, OutputT> doFn =
+        application.getTransform().getUnderlyingParDo().getNewFn();
+    final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
+
+    // If the DoFn declares state, schedule that state to be cleared when each window seen in
+    // this bundle expires. Scheduling redundant clear tasks is semantically harmless; the cache
+    // only limits how many are scheduled, to avoid performance degradation.
+    if (signature.stateDeclarations().size() > 0) {
+      for (final WindowedValue<?> element : inputBundle.getElements()) {
+        for (final BoundedWindow window : element.getWindows()) {
+          cleanupRegistry.get(
+              AppliedPTransformOutputKeyAndWindow.create(
+                  application, (StructuralKey<K>) inputBundle.getKey(), window));
+        }
+      }
+    }
+
+    TransformEvaluator<KV<K, InputT>> delegateEvaluator =
+        delegateFactory.createEvaluator(
+            (AppliedPTransform) application,
+            inputBundle.getKey(),
+            doFn,
+            application.getTransform().getUnderlyingParDo().getSideInputs(),
+            application.getTransform().getUnderlyingParDo().getMainOutputTag(),
+            application.getTransform().getUnderlyingParDo().getSideOutputTags().getAll());
+
+    return new StatefulParDoEvaluator<>(delegateEvaluator);
+  }
+
+  private class CleanupSchedulingLoader
+      extends CacheLoader<AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT>, Runnable> {
+    // On a cache miss, schedules (and returns) the callback that clears a window's state.
+    private final EvaluationContext evaluationContext;
+
+    public CleanupSchedulingLoader(EvaluationContext evaluationContext) {
+      this.evaluationContext = evaluationContext;
+    }
+
+    @Override
+    public Runnable load(
+        final AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT> transformOutputWindow) {
+      String stepName = evaluationContext.getStepName(transformOutputWindow.getTransform());
+
+      PCollection<?> pc =
+          transformOutputWindow
+              .getTransform()
+              .getOutput()
+              .get(
+                  transformOutputWindow
+                      .getTransform()
+                      .getTransform()
+                      .getUnderlyingParDo()
+                      .getMainOutputTag());
+      WindowingStrategy<?, ?> windowingStrategy = pc.getWindowingStrategy();
+      BoundedWindow window = transformOutputWindow.getWindow();
+      final DoFn<?, ?> doFn =
+          transformOutputWindow.getTransform().getTransform().getUnderlyingParDo().getNewFn();
+      final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
+      // State for this (transform, key) is reached through the corresponding step context.
+      final DirectStepContext stepContext =
+          evaluationContext
+              .getExecutionContext(
+                  transformOutputWindow.getTransform(), transformOutputWindow.getKey())
+              .getOrCreateStepContext(stepName, stepName);
+      // State is namespaced per window, using the window coder of the main output.
+      final StateNamespace namespace =
+          StateNamespaces.window(
+              (Coder<BoundedWindow>) windowingStrategy.getWindowFn().windowCoder(), window);
+
+      Runnable cleanup =
+          new Runnable() {
+            @Override
+            public void run() {
+              for (StateDeclaration stateDecl : signature.stateDeclarations().values()) {
+                StateTag<Object, ?> tag;
+                try {
+                  tag =
+                      StateTags.tagForSpec(stateDecl.id(), (StateSpec) stateDecl.field().get(doFn));
+                } catch (IllegalAccessException e) {
+                  throw new RuntimeException(
+                      String.format(
+                          "Error accessing %s for %s",
+                          StateSpec.class.getName(), doFn.getClass().getName()),
+                      e);
+                }
+                stepContext.stateInternals().state(namespace, tag).clear();
+              }
+              cleanupRegistry.invalidate(transformOutputWindow);
+            }
+          };
+
+      evaluationContext.scheduleAfterWindowExpiration(
+          transformOutputWindow.getTransform(), window, windowingStrategy, cleanup);
+      return cleanup;
+    }
+  }
+
+  @AutoValue
+  abstract static class AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT> {
+    abstract AppliedPTransform<
+            PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple,
+            StatefulParDo<K, InputT, OutputT>>
+        getTransform();
+
+    abstract StructuralKey<K> getKey();
+
+    abstract BoundedWindow getWindow();
+
+    static <K, InputT, OutputT> AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT> create(
+        AppliedPTransform<
+                PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple,
+                StatefulParDo<K, InputT, OutputT>>
+            transform,
+        StructuralKey<K> key,
+        BoundedWindow w) {
+      return new AutoValue_StatefulParDoEvaluatorFactory_AppliedPTransformOutputKeyAndWindow<>(
+          transform, key, w);
+    }
+  }
+
+  private static class StatefulParDoEvaluator<K, InputT>
+      implements TransformEvaluator<KV<K, Iterable<InputT>>> {
+    // Evaluates the underlying DoFn one (key, value) pair at a time.
+    private final TransformEvaluator<KV<K, InputT>> delegateEvaluator;
+
+    public StatefulParDoEvaluator(TransformEvaluator<KV<K, InputT>> delegateEvaluator) {
+      this.delegateEvaluator = delegateEvaluator;
+    }
+
+    @Override
+    public void processElement(WindowedValue<KV<K, Iterable<InputT>>> gbkResult) throws Exception {
+      // Feed each value of the grouped result to the delegate as its own KV element.
+      for (InputT value : gbkResult.getValue().getValue()) {
+        delegateEvaluator.processElement(
+            gbkResult.withValue(KV.of(gbkResult.getValue().getKey(), value)));
+      }
+    }
+
+    @Override
+    public TransformResult<KV<K, Iterable<InputT>>> finishBundle() throws Exception {
+      TransformResult<KV<K, InputT>> delegateResult = delegateEvaluator.finishBundle();
+      // Copy the delegate's result, re-typed for the grouped input; unprocessed elements follow.
+      StepTransformResult.Builder<KV<K, Iterable<InputT>>> regroupedResult =
+          StepTransformResult.<KV<K, Iterable<InputT>>>withHold(
+                  delegateResult.getTransform(), delegateResult.getWatermarkHold())
+              .withTimerUpdate(delegateResult.getTimerUpdate())
+              .withAggregatorChanges(delegateResult.getAggregatorChanges())
+              .withMetricUpdates(delegateResult.getLogicalMetricUpdates())
+              .addOutput(Lists.newArrayList(delegateResult.getOutputBundles()));
+
+      // The delegate may have pushed back unprocessed elements across multiple keys and windows.
+      // Since processing is single-threaded per key and window, they need not be regrouped;
+      // each pushed-back element is wrapped as a singleton KV<K, Iterable<InputT>> instead.
+      for (WindowedValue<?> untypedUnprocessed : delegateResult.getUnprocessedElements()) {
+        WindowedValue<KV<K, InputT>> windowedKv = (WindowedValue<KV<K, InputT>>) untypedUnprocessed;
+        WindowedValue<KV<K, Iterable<InputT>>> pushedBack =
+            windowedKv.withValue(
+                KV.of(
+                    windowedKv.getValue().getKey(),
+                    (Iterable<InputT>)
+                        Collections.singletonList(windowedKv.getValue().getValue())));
+
+        regroupedResult.addUnprocessedElements(pushedBack);
+      }
+
+      return regroupedResult.build();
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
index 0514c3a..a4c462a 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
@@ -28,6 +28,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupAlsoByWindow;
 import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupByKeyOnly;
 import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
+import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo;
 import org.apache.beam.sdk.io.Read;
 import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Flatten.FlattenPCollectionList;
@@ -50,6 +51,7 @@ class TransformEvaluatorRegistry implements TransformEvaluatorFactory {
             .put(Read.Bounded.class, new BoundedReadEvaluatorFactory(ctxt))
             .put(Read.Unbounded.class, new UnboundedReadEvaluatorFactory(ctxt))
             .put(ParDo.BoundMulti.class, new ParDoEvaluatorFactory<>(ctxt))
+            .put(StatefulParDo.class, new StatefulParDoEvaluatorFactory<>(ctxt))
             .put(FlattenPCollectionList.class, new FlattenEvaluatorFactory(ctxt))
             .put(ViewEvaluatorFactory.WriteView.class, new ViewEvaluatorFactory(ctxt))
             .put(Window.Bound.class, new WindowEvaluatorFactory(ctxt))

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkCallbackExecutor.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkCallbackExecutor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkCallbackExecutor.java
index 54cab7c..fcefc5f 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkCallbackExecutor.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkCallbackExecutor.java
@@ -89,6 +89,32 @@ class WatermarkCallbackExecutor {
   }
 
   /**
+   * Execute the provided {@link Runnable} after the next call to
+   * {@link #fireForWatermark(AppliedPTransform, Instant)} where {@code window} is guaranteed to
+   * be expired, i.e. past its max timestamp plus the allowed lateness of the strategy.
+   */
+  public void callOnWindowExpiration(
+      AppliedPTransform<?, ?, ?> step,
+      BoundedWindow window,
+      WindowingStrategy<?, ?> windowingStrategy,
+      Runnable runnable) {
+    WatermarkCallback callback =
+        WatermarkCallback.afterWindowExpiration(window, windowingStrategy, runnable);
+    // Lazily create this step's queue; if another caller installed one first, use that one.
+    PriorityQueue<WatermarkCallback> callbackQueue = callbacks.get(step);
+    if (callbackQueue == null) {
+      callbackQueue = new PriorityQueue<>(11, new CallbackOrdering());
+      if (callbacks.putIfAbsent(step, callbackQueue) != null) {
+        callbackQueue = callbacks.get(step);
+      }
+    }
+    // PriorityQueue is not thread-safe, so the offer is guarded by the queue's monitor.
+    synchronized (callbackQueue) {
+      callbackQueue.offer(callback);
+    }
+  }
+
+  /**
    * Schedule all pending callbacks that must have produced output by the time of the provided
    * watermark.
    */
@@ -112,6 +138,14 @@ class WatermarkCallbackExecutor {
       return new WatermarkCallback(firingAfter, callback);
     }
 
+    public static <W extends BoundedWindow> WatermarkCallback afterWindowExpiration(
+        BoundedWindow window, WindowingStrategy<?, W> strategy, Runnable callback) {
+      // Fire one milli past the window's expiration time (max timestamp plus allowed lateness),
+      // which is intended to ensure that all window expiration timers are delivered first.
+      Instant firingAfter = window.maxTimestamp().plus(strategy.getAllowedLateness()).plus(1L);
+      return new WatermarkCallback(firingAfter, callback);
+    }
+
     private final Instant fireAfter;
     private final Runnable callback;
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
new file mode 100644
index 0000000..ecf11ed
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.direct;
+
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.nullValue;
+import static org.junit.Assert.assertThat;
+import static org.mockito.Matchers.anyList;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
+import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle;
+import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo;
+import org.apache.beam.runners.direct.WatermarkManager.TimerUpdate;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.util.ReadyCheckingSideInputReader;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowingStrategy;
+import org.apache.beam.sdk.util.state.CopyOnAccessInMemoryStateInternals;
+import org.apache.beam.sdk.util.state.StateInternals;
+import org.apache.beam.sdk.util.state.StateNamespace;
+import org.apache.beam.sdk.util.state.StateNamespaces;
+import org.apache.beam.sdk.util.state.StateSpec;
+import org.apache.beam.sdk.util.state.StateSpecs;
+import org.apache.beam.sdk.util.state.StateTag;
+import org.apache.beam.sdk.util.state.StateTags;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Matchers;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+
+/** Tests for {@link StatefulParDoEvaluatorFactory}. */
+@RunWith(JUnit4.class)
+public class StatefulParDoEvaluatorFactoryTest implements Serializable {
+  @Mock private transient EvaluationContext mockEvaluationContext;
+  @Mock private transient DirectExecutionContext mockExecutionContext;
+  @Mock private transient DirectExecutionContext.DirectStepContext mockStepContext;
+  @Mock private transient ReadyCheckingSideInputReader mockSideInputReader;
+  @Mock private transient UncommittedBundle<Integer> mockUncommittedBundle;
+
+  private static final String KEY = "any-key";
+  private transient StateInternals<Object> stateInternals =
+      CopyOnAccessInMemoryStateInternals.<Object>withUnderlying(KEY, null);
+
+  private static final BundleFactory BUNDLE_FACTORY = ImmutableListBundleFactory.create();
+
+  @Before
+  public void setup() {
+    MockitoAnnotations.initMocks(this);
+    when((StateInternals<Object>) mockStepContext.stateInternals()).thenReturn(stateInternals);
+  }
+
+  @Test
+  public void windowCleanupScheduled() throws Exception {
+    // To test the factory, first we set up a pipeline and then we use the constructed
+    // pipeline to create the right parameters to pass to the factory
+    TestPipeline pipeline = TestPipeline.create();
+
+    final String stateId = "my-state-id";
+
+    // For consistency, window it into FixedWindows. Actually we will fabricate an input bundle.
+    PCollection<KV<String, Integer>> input =
+        pipeline
+            .apply(Create.of(KV.of("hello", 1), KV.of("hello", 2)))
+            .apply(Window.<KV<String, Integer>>into(FixedWindows.of(Duration.millis(10))));
+
+    PCollection<Integer> produced =
+        input.apply(
+            ParDo.of(
+                new DoFn<KV<String, Integer>, Integer>() {
+                  @StateId(stateId)
+                  private final StateSpec<Object, ValueState<String>> spec =
+                      StateSpecs.value(StringUtf8Coder.of());
+
+                  @ProcessElement
+                  public void process(ProcessContext c) {}
+                }));
+
+    StatefulParDoEvaluatorFactory<String, Integer, Integer> factory =
+        new StatefulParDoEvaluatorFactory<>(mockEvaluationContext);
+
+    AppliedPTransform<
+            PCollection<? extends KV<String, Iterable<Integer>>>, PCollectionTuple,
+            StatefulParDo<String, Integer, Integer>>
+        producingTransform = (AppliedPTransform) produced.getProducingTransformInternal();
+
+    // Then there will be a digging down to the step context to get the state internals
+    when(mockEvaluationContext.getExecutionContext(
+            eq(producingTransform), Mockito.<StructuralKey>any()))
+        .thenReturn(mockExecutionContext);
+    when(mockExecutionContext.getOrCreateStepContext(anyString(), anyString()))
+        .thenReturn(mockStepContext);
+
+    IntervalWindow firstWindow = new IntervalWindow(new Instant(0), new Instant(10));
+    IntervalWindow secondWindow = new IntervalWindow(new Instant(10), new Instant(20));
+
+    StateNamespace firstWindowNamespace =
+        StateNamespaces.window(IntervalWindow.getCoder(), firstWindow);
+    StateNamespace secondWindowNamespace =
+        StateNamespaces.window(IntervalWindow.getCoder(), secondWindow);
+    StateTag<Object, ValueState<String>> tag =
+        StateTags.tagForSpec(stateId, StateSpecs.value(StringUtf8Coder.of()));
+
+    // Set up non-empty state. We don't mock + verify calls to clear() but instead
+    // check that state is actually empty. We shouldn't care how it is accomplished.
+    stateInternals.state(firstWindowNamespace, tag).write("first");
+    stateInternals.state(secondWindowNamespace, tag).write("second");
+
+    // A single bundle with one element in each of the two windows. Merely creating the
+    // evaluator should register a cleanup callback per window, and each callback should
+    // clear only the state of its own window.
+    CommittedBundle<KV<String, Integer>> inputBundle =
+        BUNDLE_FACTORY
+            .createBundle(input)
+            .add(
+                WindowedValue.of(
+                    KV.of("hello", 1), new Instant(3), firstWindow, PaneInfo.NO_FIRING))
+            .add(
+                WindowedValue.of(
+                    KV.of("hello", 2), new Instant(11), secondWindow, PaneInfo.NO_FIRING))
+            .commit(Instant.now());
+
+    // Merely creating the evaluator should suffice to register the cleanup callback
+    factory.forApplication(producingTransform, inputBundle);
+
+    ArgumentCaptor<Runnable> argumentCaptor = ArgumentCaptor.forClass(Runnable.class);
+    verify(mockEvaluationContext)
+        .scheduleAfterWindowExpiration(
+            eq(producingTransform),
+            eq(firstWindow),
+            Mockito.<WindowingStrategy<?, ?>>any(),
+            argumentCaptor.capture());
+
+    // Should actually clear the state for the first window
+    argumentCaptor.getValue().run();
+    assertThat(stateInternals.state(firstWindowNamespace, tag).read(), nullValue());
+    assertThat(stateInternals.state(secondWindowNamespace, tag).read(), equalTo("second"));
+
+    verify(mockEvaluationContext)
+        .scheduleAfterWindowExpiration(
+            eq(producingTransform),
+            eq(secondWindow),
+            Mockito.<WindowingStrategy<?, ?>>any(),
+            argumentCaptor.capture());
+
+    // Should actually clear the state for the second window
+    argumentCaptor.getValue().run();
+    assertThat(stateInternals.state(secondWindowNamespace, tag).read(), nullValue());
+  }
+
+  /**
+   * A test that explicitly delays a side input so that the main input will have to be reprocessed,
+   * testing that {@code finishBundle()} re-assembles the GBK outputs correctly.
+   */
+  @Test
+  public void testUnprocessedElements() throws Exception {
+    // To test the factory, first we set up a pipeline and then we use the constructed
+    // pipeline to create the right parameters to pass to the factory
+    TestPipeline pipeline = TestPipeline.create();
+
+    final String stateId = "my-state-id";
+
+    // For consistency, window it into FixedWindows. Actually we will fabricate an input bundle.
+    PCollection<KV<String, Integer>> mainInput =
+        pipeline
+            .apply(Create.of(KV.of("hello", 1), KV.of("hello", 2)))
+            .apply(Window.<KV<String, Integer>>into(FixedWindows.of(Duration.millis(10))));
+
+    final PCollectionView<List<Integer>> sideInput =
+        pipeline
+            .apply("Create side input", Create.of(42))
+            .apply("Window side input", Window.<Integer>into(FixedWindows.of(Duration.millis(10))))
+            .apply("View side input", View.<Integer>asList());
+
+    PCollection<Integer> produced =
+        mainInput.apply(
+            ParDo.withSideInputs(sideInput)
+                .of(
+                    new DoFn<KV<String, Integer>, Integer>() {
+                      @StateId(stateId)
+                      private final StateSpec<Object, ValueState<String>> spec =
+                          StateSpecs.value(StringUtf8Coder.of());
+
+                      @ProcessElement
+                      public void process(ProcessContext c) {}
+                    }));
+
+    StatefulParDoEvaluatorFactory<String, Integer, Integer> factory =
+        new StatefulParDoEvaluatorFactory<>(mockEvaluationContext);
+
+    // This will be the stateful ParDo from the expansion
+    AppliedPTransform<
+            PCollection<KV<String, Iterable<Integer>>>, PCollectionTuple,
+            StatefulParDo<String, Integer, Integer>>
+        producingTransform = (AppliedPTransform) produced.getProducingTransformInternal();
+
+    // Then there will be a digging down to the step context to get the state internals
+    when(mockEvaluationContext.getExecutionContext(
+            eq(producingTransform), Mockito.<StructuralKey>any()))
+        .thenReturn(mockExecutionContext);
+    when(mockExecutionContext.getOrCreateStepContext(anyString(), anyString()))
+        .thenReturn(mockStepContext);
+    when(mockEvaluationContext.createBundle(Matchers.<PCollection<Integer>>any()))
+        .thenReturn(mockUncommittedBundle);
+    when(mockStepContext.getTimerUpdate()).thenReturn(TimerUpdate.empty());
+
+    // And digging to check whether the window is ready
+    when(mockEvaluationContext.createSideInputReader(anyList())).thenReturn(mockSideInputReader);
+    when(mockSideInputReader.isReady(
+            Matchers.<PCollectionView<?>>any(), Matchers.<BoundedWindow>any()))
+        .thenReturn(false);
+
+    IntervalWindow firstWindow = new IntervalWindow(new Instant(0), new Instant(10));
+
+    // A single GBK output element in the first window. Because the side input is never ready,
+    // the evaluator cannot process it and should push it back as unprocessed, with all of its
+    // values intact.
+    WindowedValue<KV<String, Iterable<Integer>>> gbkOutputElement =
+        WindowedValue.of(
+            KV.<String, Iterable<Integer>>of("hello", Lists.newArrayList(1, 13, 15)),
+            new Instant(3),
+            firstWindow,
+            PaneInfo.NO_FIRING);
+    CommittedBundle<KV<String, Iterable<Integer>>> inputBundle =
+        BUNDLE_FACTORY
+            .createBundle(producingTransform.getInput())
+            .add(gbkOutputElement)
+            .commit(Instant.now());
+    TransformEvaluator<KV<String, Iterable<Integer>>> evaluator =
+        factory.forApplication(producingTransform, inputBundle);
+    evaluator.processElement(gbkOutputElement);
+
+    // This should push back every element as a KV<String, Iterable<Integer>>
+    // in the appropriate window, keyed by the single key "hello".
+    TransformResult<KV<String, Iterable<Integer>>> result = evaluator.finishBundle();
+
+    List<Integer> pushedBackInts = new ArrayList<>();
+
+    for (WindowedValue<?> unprocessedElement : result.getUnprocessedElements()) {
+      WindowedValue<KV<String, Iterable<Integer>>> unprocessedKv =
+          (WindowedValue<KV<String, Iterable<Integer>>>) unprocessedElement;
+
+      assertThat(
+          Iterables.getOnlyElement(unprocessedElement.getWindows()),
+          equalTo((BoundedWindow) firstWindow));
+      assertThat(unprocessedKv.getValue().getKey(), equalTo("hello"));
+      for (Integer i : unprocessedKv.getValue().getValue()) {
+        pushedBackInts.add(i);
+      }
+    }
+    assertThat(pushedBackInts, containsInAnyOrder(1, 13, 15));
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
index 221d942..3f1a3f9 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
@@ -315,7 +315,7 @@ public abstract class DoFn<InputT, OutputT> implements Serializable, HasDisplayD
    *
    * <p>See {@link #getOutputTypeDescriptor} for more discussion.
    */
-  protected TypeDescriptor<InputT> getInputTypeDescriptor() {
+  public TypeDescriptor<InputT> getInputTypeDescriptor() {
     return new TypeDescriptor<InputT>(getClass()) {};
   }
 
@@ -330,7 +330,7 @@ public abstract class DoFn<InputT, OutputT> implements Serializable, HasDisplayD
    * for choosing a default output {@code Coder<O>} for the output
    * {@code PCollection<O>}.
    */
-  protected TypeDescriptor<OutputT> getOutputTypeDescriptor() {
+  public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
     return new TypeDescriptor<OutputT>(getClass()) {};
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java
index 9bf9003..2d2c1fd 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java
@@ -671,7 +671,7 @@ public abstract class OldDoFn<InputT, OutputT> implements Serializable, HasDispl
     }
 
     @Override
-    protected TypeDescriptor<InputT> getInputTypeDescriptor() {
+    public TypeDescriptor<InputT> getInputTypeDescriptor() {
       return OldDoFn.this.getInputTypeDescriptor();
     }
 
@@ -681,7 +681,7 @@ public abstract class OldDoFn<InputT, OutputT> implements Serializable, HasDispl
     }
 
     @Override
-    protected TypeDescriptor<OutputT> getOutputTypeDescriptor() {
+    public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
       return OldDoFn.this.getOutputTypeDescriptor();
     }
   }
@@ -746,12 +746,12 @@ public abstract class OldDoFn<InputT, OutputT> implements Serializable, HasDispl
     }
 
     @Override
-    protected TypeDescriptor<InputT> getInputTypeDescriptor() {
+    public TypeDescriptor<InputT> getInputTypeDescriptor() {
       return OldDoFn.this.getInputTypeDescriptor();
     }
 
     @Override
-    protected TypeDescriptor<OutputT> getOutputTypeDescriptor() {
+    public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
       return OldDoFn.this.getOutputTypeDescriptor();
     }
   }