You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ke...@apache.org on 2016/12/20 20:40:25 UTC

[4/5] incubator-beam git commit: Port direct runner StatefulParDo to KeyedWorkItem

Port direct runner StatefulParDo to KeyedWorkItem


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/1f018ab6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/1f018ab6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/1f018ab6

Branch: refs/heads/master
Commit: 1f018ab69fdcc720a10e2aeb8ec1eea1c06e1cbc
Parents: d040b7f
Author: Kenneth Knowles <kl...@google.com>
Authored: Mon Dec 12 19:49:58 2016 -0800
Committer: Kenneth Knowles <kl...@google.com>
Committed: Tue Dec 20 11:19:07 2016 -0800

----------------------------------------------------------------------
 .../direct/KeyedPValueTrackingVisitor.java      | 13 ++-
 .../direct/ParDoMultiOverrideFactory.java       | 94 +++++++++++++++++---
 .../direct/StatefulParDoEvaluatorFactory.java   | 36 ++++----
 .../direct/KeyedPValueTrackingVisitorTest.java  | 69 ++++++++++++--
 .../StatefulParDoEvaluatorFactoryTest.java      | 51 +++++++----
 5 files changed, 205 insertions(+), 58 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1f018ab6/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
index e91a768..65c41e0 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
@@ -31,6 +31,7 @@ import org.apache.beam.sdk.Pipeline.PipelineVisitor;
 import org.apache.beam.sdk.runners.TransformHierarchy;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.values.PValue;
 
 /**
@@ -105,7 +106,15 @@ class KeyedPValueTrackingVisitor implements PipelineVisitor {
   }
 
   private static boolean isKeyPreserving(PTransform<?, ?> transform) {
-    // There are currently no key-preserving transforms; this lays the infrastructure for them
-    return false;
+    // This is a hacky check for what is considered key-preserving to the direct runner.
+    // The most obvious alternative would be a package-private marker interface, but
+    // better to make this obviously hacky so it is less likely to proliferate. Meanwhile
+    // we intend to allow explicit expression of key-preserving DoFn in the model.
+    if (transform instanceof ParDo.BoundMulti) {
+      ParDo.BoundMulti<?, ?> parDo = (ParDo.BoundMulti<?, ?>) transform;
+      return parDo.getFn() instanceof ParDoMultiOverrideFactory.ToKeyedWorkItem;
+    } else {
+      return false;
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1f018ab6/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
index c5bc069..2cea999 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
@@ -17,9 +17,15 @@
  */
 package org.apache.beam.runners.direct;
 
+import static com.google.common.base.Preconditions.checkState;
+
+import org.apache.beam.runners.core.KeyedWorkItem;
+import org.apache.beam.runners.core.KeyedWorkItemCoder;
+import org.apache.beam.runners.core.KeyedWorkItems;
 import org.apache.beam.runners.core.SplittableParDo;
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
@@ -28,6 +34,8 @@ import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.ParDo.BoundMulti;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
@@ -84,16 +92,41 @@ class ParDoMultiOverrideFactory<InputT, OutputT>
     @Override
     public PCollectionTuple expand(PCollection<KV<K, InputT>> input) {
 
-      PCollectionTuple outputs = input
-          .apply("Group by key", GroupByKey.<K, InputT>create())
-          .apply("Stateful ParDo", new StatefulParDo<>(underlyingParDo, input));
+      // A KvCoder is required since this goes through GBK. Further, WindowedValueCoder
+      // is not registered by default, so we explicitly set the relevant coders.
+      checkState(input.getCoder() instanceof KvCoder,
+          "Input to a %s using state requires a %s, but the coder was %s",
+          ParDo.class.getSimpleName(),
+          KvCoder.class.getSimpleName(),
+          input.getCoder());
+      KvCoder<K, InputT> kvCoder = (KvCoder<K, InputT>) input.getCoder();
+      Coder<K> keyCoder = kvCoder.getKeyCoder();
+      Coder<? extends BoundedWindow> windowCoder =
+          input.getWindowingStrategy().getWindowFn().windowCoder();
+
+      PCollectionTuple outputs =
+          input
+              // Stash the original timestamps, etc, for when it is fed to the user's DoFn
+              .apply("Reify timestamps", ParDo.of(new ReifyWindowedValueFn<K, InputT>()))
+              .setCoder(KvCoder.of(keyCoder, WindowedValue.getFullCoder(kvCoder, windowCoder)))
+
+              // A full GBK to group by key _and_ window
+              .apply("Group by key", GroupByKey.<K, WindowedValue<KV<K, InputT>>>create())
+
+              // Adapt to KeyedWorkItem; that is how this runner delivers timers
+              .apply("To KeyedWorkItem", ParDo.of(new ToKeyedWorkItem<K, InputT>()))
+              .setCoder(KeyedWorkItemCoder.of(keyCoder, kvCoder, windowCoder))
+
+              // Explode the resulting iterable into elements that are exactly the ones from
+              // the input
+              .apply("Stateful ParDo", new StatefulParDo<>(underlyingParDo, input));
 
       return outputs;
     }
   }
 
   static class StatefulParDo<K, InputT, OutputT>
-      extends PTransform<PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple> {
+      extends PTransform<PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple> {
     private final transient ParDo.BoundMulti<KV<K, InputT>, OutputT> underlyingParDo;
     private final transient PCollection<KV<K, InputT>> originalInput;
 
@@ -110,21 +143,58 @@ class ParDoMultiOverrideFactory<InputT, OutputT>
 
     @Override
     public <T> Coder<T> getDefaultOutputCoder(
-        PCollection<? extends KV<K, Iterable<InputT>>> input, TypedPValue<T> output)
+        PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>> input, TypedPValue<T> output)
         throws CannotProvideCoderException {
       return underlyingParDo.getDefaultOutputCoder(originalInput, output);
     }
 
-    public PCollectionTuple expand(PCollection<? extends KV<K, Iterable<InputT>>> input) {
+    @Override
+    public PCollectionTuple expand(PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>> input) {
 
-      PCollectionTuple outputs = PCollectionTuple.ofPrimitiveOutputsInternal(
-          input.getPipeline(),
-          TupleTagList.of(underlyingParDo.getMainOutputTag())
-              .and(underlyingParDo.getSideOutputTags().getAll()),
-          input.getWindowingStrategy(),
-          input.isBounded());
+      PCollectionTuple outputs =
+          PCollectionTuple.ofPrimitiveOutputsInternal(
+              input.getPipeline(),
+              TupleTagList.of(underlyingParDo.getMainOutputTag())
+                  .and(underlyingParDo.getSideOutputTags().getAll()),
+              input.getWindowingStrategy(),
+              input.isBounded());
 
       return outputs;
     }
   }
+
+  /**
+   * Reifies each input {@link KV} into a {@link WindowedValue} keyed by that element's key.
+   *
+   * <p>The element's timestamp, window, and pane are captured so they survive the subsequent
+   * {@link GroupByKey}, letting the user's stateful {@link DoFn} later see the original metadata.
+   * This is not the distinguished key-preserving {@link DoFn}; that is {@link ToKeyedWorkItem}.
+   */
+  static class ReifyWindowedValueFn<K, V> extends DoFn<KV<K, V>, KV<K, WindowedValue<KV<K, V>>>> {
+    @ProcessElement
+    public void processElement(final ProcessContext c, final BoundedWindow window) {
+      c.output(
+          KV.of(
+              c.element().getKey(),
+              WindowedValue.of(c.element(), c.timestamp(), window, c.pane())));
+    }
+  }
+
+  /**
+   * A runner-specific primitive that is just a key-preserving {@link ParDo}, but we do not have the
+   * machinery to detect or enforce that yet; the direct runner recognizes it by its class.
+   *
+   * <p>Each {@link GroupByKey} output is wrapped in one {@link KeyedWorkItem} carrying its
+   * elements, since that is how this runner delivers timers. The grouped values are not exploded
+   * here; the stateful ParDo evaluator later processes them one {@link WindowedValue} at a time.
+   */
+  static class ToKeyedWorkItem<K, V>
+      extends DoFn<KV<K, Iterable<WindowedValue<KV<K, V>>>>, KeyedWorkItem<K, KV<K, V>>> {
+    // NOTE(review): processElement never reads its window parameter; confirm it is still needed.
+    @ProcessElement
+    public void processElement(final ProcessContext c, final BoundedWindow window) {
+      final K key = c.element().getKey();
+      c.output(KeyedWorkItems.elementsWorkItem(key, c.element().getValue()));
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1f018ab6/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
index 1f64d9a..5f9d8f4 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
@@ -23,6 +23,8 @@ import com.google.common.cache.CacheLoader;
 import com.google.common.cache.LoadingCache;
 import com.google.common.collect.Lists;
 import java.util.Collections;
+import org.apache.beam.runners.core.KeyedWorkItem;
+import org.apache.beam.runners.core.KeyedWorkItems;
 import org.apache.beam.runners.direct.DirectExecutionContext.DirectStepContext;
 import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
 import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo;
@@ -77,12 +79,12 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo
   }
 
   @SuppressWarnings({"unchecked", "rawtypes"})
-  private TransformEvaluator<KV<K, Iterable<InputT>>> createEvaluator(
+  private TransformEvaluator<KeyedWorkItem<K, KV<K, InputT>>> createEvaluator(
       AppliedPTransform<
-              PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple,
+              PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple,
               StatefulParDo<K, InputT, OutputT>>
           application,
-      CommittedBundle<KV<K, Iterable<InputT>>> inputBundle)
+      CommittedBundle<KeyedWorkItem<K, KV<K, InputT>>> inputBundle)
       throws Exception {
 
     final DoFn<KV<K, InputT>, OutputT> doFn =
@@ -185,7 +187,7 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo
   @AutoValue
   abstract static class AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT> {
     abstract AppliedPTransform<
-            PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple,
+            PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple,
             StatefulParDo<K, InputT, OutputT>>
         getTransform();
 
@@ -195,7 +197,7 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo
 
     static <K, InputT, OutputT> AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT> create(
         AppliedPTransform<
-                PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple,
+                PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple,
                 StatefulParDo<K, InputT, OutputT>>
             transform,
         StructuralKey<K> key,
@@ -206,7 +208,7 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo
   }
 
   private static class StatefulParDoEvaluator<K, InputT>
-      implements TransformEvaluator<KV<K, Iterable<InputT>>> {
+      implements TransformEvaluator<KeyedWorkItem<K, KV<K, InputT>>> {
 
     private final TransformEvaluator<KV<K, InputT>> delegateEvaluator;
 
@@ -215,20 +217,20 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo
     }
 
     @Override
-    public void processElement(WindowedValue<KV<K, Iterable<InputT>>> gbkResult) throws Exception {
+    public void processElement(WindowedValue<KeyedWorkItem<K, KV<K, InputT>>> gbkResult)
+        throws Exception {
 
-      for (InputT value : gbkResult.getValue().getValue()) {
-        delegateEvaluator.processElement(
-            gbkResult.withValue(KV.of(gbkResult.getValue().getKey(), value)));
+      for (WindowedValue<KV<K, InputT>> windowedValue : gbkResult.getValue().elementsIterable()) {
+        delegateEvaluator.processElement(windowedValue);
       }
     }
 
     @Override
-    public TransformResult<KV<K, Iterable<InputT>>> finishBundle() throws Exception {
+    public TransformResult<KeyedWorkItem<K, KV<K, InputT>>> finishBundle() throws Exception {
       TransformResult<KV<K, InputT>> delegateResult = delegateEvaluator.finishBundle();
 
-      StepTransformResult.Builder<KV<K, Iterable<InputT>>> regroupedResult =
-          StepTransformResult.<KV<K, Iterable<InputT>>>withHold(
+      StepTransformResult.Builder<KeyedWorkItem<K, KV<K, InputT>>> regroupedResult =
+          StepTransformResult.<KeyedWorkItem<K, KV<K, InputT>>>withHold(
                   delegateResult.getTransform(), delegateResult.getWatermarkHold())
               .withTimerUpdate(delegateResult.getTimerUpdate())
               .withAggregatorChanges(delegateResult.getAggregatorChanges())
@@ -240,12 +242,10 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo
       // outputs, but just make a bunch of singletons
       for (WindowedValue<?> untypedUnprocessed : delegateResult.getUnprocessedElements()) {
         WindowedValue<KV<K, InputT>> windowedKv = (WindowedValue<KV<K, InputT>>) untypedUnprocessed;
-        WindowedValue<KV<K, Iterable<InputT>>> pushedBack =
+        WindowedValue<KeyedWorkItem<K, KV<K, InputT>>> pushedBack =
             windowedKv.withValue(
-                KV.of(
-                    windowedKv.getValue().getKey(),
-                    (Iterable<InputT>)
-                        Collections.singletonList(windowedKv.getValue().getValue())));
+                KeyedWorkItems.elementsWorkItem(
+                    windowedKv.getValue().getKey(), Collections.singleton(windowedKv)));
 
         regroupedResult.addUnprocessedElements(pushedBack);
       }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1f018ab6/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java
index a357005..a1fb81b 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java
@@ -22,8 +22,10 @@ import static org.hamcrest.Matchers.not;
 import static org.junit.Assert.assertThat;
 
 import java.util.Collections;
+import org.apache.beam.runners.core.KeyedWorkItem;
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.testing.TestPipeline;
@@ -32,8 +34,12 @@ import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.Keys;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.joda.time.Instant;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -41,9 +47,7 @@ import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
-/**
- * Tests for {@link KeyedPValueTrackingVisitor}.
- */
+/** Tests for {@link KeyedPValueTrackingVisitor}. */
 @RunWith(JUnit4.class)
 public class KeyedPValueTrackingVisitorTest {
   @Rule public ExpectedException thrown = ExpectedException.none();
@@ -61,8 +65,7 @@ public class KeyedPValueTrackingVisitorTest {
   @Test
   public void groupByKeyProducesKeyedOutput() {
     PCollection<KV<String, Iterable<Integer>>> keyed =
-        p.apply(Create.of(KV.of("foo", 3)))
-            .apply(GroupByKey.<String, Integer>create());
+        p.apply(Create.of(KV.of("foo", 3))).apply(GroupByKey.<String, Integer>create());
 
     p.traverseTopologically(visitor);
     assertThat(visitor.getKeyedPValues(), hasItem(keyed));
@@ -91,16 +94,66 @@ public class KeyedPValueTrackingVisitorTest {
   }
 
   @Test
+  public void unkeyedInputWithKeyPreserving() {
+    // Input comes from Create, not GroupByKey, so the key-preserving output stays unkeyed.
+    PCollection<KV<String, Iterable<WindowedValue<KV<String, Integer>>>>> input =
+        p.apply(
+            Create.of(
+                    KV.of(
+                        "hello",
+                        (Iterable<WindowedValue<KV<String, Integer>>>)
+                            Collections.<WindowedValue<KV<String, Integer>>>emptyList()))
+                .withCoder(
+                    KvCoder.of(
+                        StringUtf8Coder.of(),
+                        IterableCoder.of(
+                            WindowedValue.getValueOnlyCoder(
+                                KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()))))));
+
+    PCollection<KeyedWorkItem<String, KV<String, Integer>>> unkeyed =
+        input.apply(ParDo.of(new ParDoMultiOverrideFactory.ToKeyedWorkItem<String, Integer>()));
+
+    p.traverseTopologically(visitor);
+    assertThat(visitor.getKeyedPValues(), not(hasItem(unkeyed)));
+  }
+
+  @Test
+  public void keyedInputWithKeyPreserving() {
+    // GroupByKey output is keyed, and ToKeyedWorkItem preserves that keyedness downstream.
+    PCollection<KV<String, WindowedValue<KV<String, Integer>>>> input =
+        p.apply(
+            Create.of(
+                    KV.of(
+                        "hello",
+                        WindowedValue.of(
+                            KV.of("hello", 3),
+                            new Instant(0),
+                            new IntervalWindow(new Instant(0), new Instant(9)),
+                            PaneInfo.NO_FIRING)))
+                .withCoder(
+                    KvCoder.of(
+                        StringUtf8Coder.of(),
+                        WindowedValue.getValueOnlyCoder(
+                            KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of())))));
+
+    PCollection<KeyedWorkItem<String, KV<String, Integer>>> keyed =
+        input
+            .apply(GroupByKey.<String, WindowedValue<KV<String, Integer>>>create())
+            .apply(ParDo.of(new ParDoMultiOverrideFactory.ToKeyedWorkItem<String, Integer>()));
+
+    p.traverseTopologically(visitor);
+    assertThat(visitor.getKeyedPValues(), hasItem(keyed));
+  }
+
+  @Test
   public void traverseMultipleTimesThrows() {
     p.apply(
-            Create.<KV<Integer, Void>>of(
-                    KV.of(1, (Void) null), KV.of(2, (Void) null), KV.of(3, (Void) null))
+            Create.of(KV.of(1, (Void) null), KV.of(2, (Void) null), KV.of(3, (Void) null))
                 .withCoder(KvCoder.of(VarIntCoder.of(), VoidCoder.of())))
         .apply(GroupByKey.<Integer, Void>create())
         .apply(Keys.<Integer>create());
 
     p.traverseTopologically(visitor);
-
     thrown.expect(IllegalStateException.class);
     thrown.expectMessage("already been finalized");
     thrown.expectMessage(KeyedPValueTrackingVisitor.class.getSimpleName());

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1f018ab6/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
index d312aa3..b88d5e0 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
@@ -27,12 +27,14 @@ import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Iterables;
-import com.google.common.collect.Lists;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import org.apache.beam.runners.core.KeyedWorkItem;
+import org.apache.beam.runners.core.KeyedWorkItems;
 import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
 import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle;
 import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo;
@@ -136,7 +138,7 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable {
         new StatefulParDoEvaluatorFactory(mockEvaluationContext);
 
     AppliedPTransform<
-            PCollection<? extends KV<String, Iterable<Integer>>>, PCollectionTuple,
+            PCollection<? extends KeyedWorkItem<String, KV<String, Integer>>>, PCollectionTuple,
             StatefulParDo<String, Integer, Integer>>
         producingTransform = (AppliedPTransform) DirectGraphs.getProducer(produced);
 
@@ -245,7 +247,7 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable {
 
     // This will be the stateful ParDo from the expansion
     AppliedPTransform<
-            PCollection<KV<String, Iterable<Integer>>>, PCollectionTuple,
+            PCollection<KeyedWorkItem<String, KV<String, Integer>>>, PCollectionTuple,
             StatefulParDo<String, Integer, Integer>>
         producingTransform = (AppliedPTransform) DirectGraphs.getProducer(produced);
 
@@ -270,37 +272,50 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable {
     // A single bundle with some elements in the global window; it should register cleanup for the
     // global window state merely by having the evaluator created. The cleanup logic does not
     // depend on the window.
-    WindowedValue<KV<String, Iterable<Integer>>> gbkOutputElement =
-        WindowedValue.of(
-            KV.<String, Iterable<Integer>>of("hello", Lists.newArrayList(1, 13, 15)),
-            new Instant(3),
-            firstWindow,
-            PaneInfo.NO_FIRING);
-    CommittedBundle<KV<String, Iterable<Integer>>> inputBundle =
+    String key = "hello";
+    WindowedValue<KV<String, Integer>> firstKv = WindowedValue.of(
+        KV.of(key, 1),
+        new Instant(3),
+        firstWindow,
+        PaneInfo.NO_FIRING);
+
+    WindowedValue<KeyedWorkItem<String, KV<String, Integer>>> gbkOutputElement =
+        firstKv.withValue(
+            KeyedWorkItems.elementsWorkItem(
+                "hello",
+                ImmutableList.of(
+                    firstKv,
+                    firstKv.withValue(KV.of(key, 13)),
+                    firstKv.withValue(KV.of(key, 15)))));
+
+    CommittedBundle<KeyedWorkItem<String, KV<String, Integer>>> inputBundle =
         BUNDLE_FACTORY
             .createBundle(producingTransform.getInput())
             .add(gbkOutputElement)
             .commit(Instant.now());
-    TransformEvaluator<KV<String, Iterable<Integer>>> evaluator =
+    TransformEvaluator<KeyedWorkItem<String, KV<String, Integer>>> evaluator =
         factory.forApplication(producingTransform, inputBundle);
+
     evaluator.processElement(gbkOutputElement);
 
     // This should push back every element as a KV<String, Iterable<Integer>>
     // in the appropriate window. Since the keys are equal they are single-threaded
-    TransformResult<KV<String, Iterable<Integer>>> result = evaluator.finishBundle();
+    TransformResult<KeyedWorkItem<String, KV<String, Integer>>> result =
+        evaluator.finishBundle();
 
     List<Integer> pushedBackInts = new ArrayList<>();
 
-    for (WindowedValue<?> unprocessedElement : result.getUnprocessedElements()) {
-      WindowedValue<KV<String, Iterable<Integer>>> unprocessedKv =
-          (WindowedValue<KV<String, Iterable<Integer>>>) unprocessedElement;
+    for (WindowedValue<? extends KeyedWorkItem<String, KV<String, Integer>>> unprocessedElement :
+        result.getUnprocessedElements()) {
 
       assertThat(
           Iterables.getOnlyElement(unprocessedElement.getWindows()),
           equalTo((BoundedWindow) firstWindow));
-      assertThat(unprocessedKv.getValue().getKey(), equalTo("hello"));
-      for (Integer i : unprocessedKv.getValue().getValue()) {
-        pushedBackInts.add(i);
+
+      assertThat(unprocessedElement.getValue().key(), equalTo("hello"));
+      for (WindowedValue<KV<String, Integer>> windowedKv :
+          unprocessedElement.getValue().elementsIterable()) {
+        pushedBackInts.add(windowedKv.getValue().getValue());
       }
     }
     assertThat(pushedBackInts, containsInAnyOrder(1, 13, 15));