You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ke...@apache.org on 2016/12/21 22:49:52 UTC
[18/51] [abbrv] incubator-beam git commit: Port direct runner
StatefulParDo to KeyedWorkItem
Port direct runner StatefulParDo to KeyedWorkItem
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/1f018ab6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/1f018ab6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/1f018ab6
Branch: refs/heads/python-sdk
Commit: 1f018ab69fdcc720a10e2aeb8ec1eea1c06e1cbc
Parents: d040b7f
Author: Kenneth Knowles <kl...@google.com>
Authored: Mon Dec 12 19:49:58 2016 -0800
Committer: Kenneth Knowles <kl...@google.com>
Committed: Tue Dec 20 11:19:07 2016 -0800
----------------------------------------------------------------------
.../direct/KeyedPValueTrackingVisitor.java | 13 ++-
.../direct/ParDoMultiOverrideFactory.java | 94 +++++++++++++++++---
.../direct/StatefulParDoEvaluatorFactory.java | 36 ++++----
.../direct/KeyedPValueTrackingVisitorTest.java | 69 ++++++++++++--
.../StatefulParDoEvaluatorFactoryTest.java | 51 +++++++----
5 files changed, 205 insertions(+), 58 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1f018ab6/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
index e91a768..65c41e0 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
@@ -31,6 +31,7 @@ import org.apache.beam.sdk.Pipeline.PipelineVisitor;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PValue;
/**
@@ -105,7 +106,15 @@ class KeyedPValueTrackingVisitor implements PipelineVisitor {
}
private static boolean isKeyPreserving(PTransform<?, ?> transform) {
- // There are currently no key-preserving transforms; this lays the infrastructure for them
- return false;
+ // This is a hacky check for what is considered key-preserving to the direct runner.
+ // The most obvious alternative would be a package-private marker interface, but
+ // better to make this obviously hacky so it is less likely to proliferate. Meanwhile
+ // we intend to allow explicit expression of key-preserving DoFn in the model.
+ if (transform instanceof ParDo.BoundMulti) {
+ ParDo.BoundMulti<?, ?> parDo = (ParDo.BoundMulti<?, ?>) transform;
+ return parDo.getFn() instanceof ParDoMultiOverrideFactory.ToKeyedWorkItem;
+ } else {
+ return false;
+ }
}
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1f018ab6/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
index c5bc069..2cea999 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
@@ -17,9 +17,15 @@
*/
package org.apache.beam.runners.direct;
+import static com.google.common.base.Preconditions.checkState;
+
+import org.apache.beam.runners.core.KeyedWorkItem;
+import org.apache.beam.runners.core.KeyedWorkItemCoder;
+import org.apache.beam.runners.core.KeyedWorkItems;
import org.apache.beam.runners.core.SplittableParDo;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.runners.PTransformOverrideFactory;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
@@ -28,6 +34,8 @@ import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.ParDo.BoundMulti;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
@@ -84,16 +92,41 @@ class ParDoMultiOverrideFactory<InputT, OutputT>
@Override
public PCollectionTuple expand(PCollection<KV<K, InputT>> input) {
- PCollectionTuple outputs = input
- .apply("Group by key", GroupByKey.<K, InputT>create())
- .apply("Stateful ParDo", new StatefulParDo<>(underlyingParDo, input));
+ // A KvCoder is required since this goes through GBK. Further, WindowedValueCoder
+ // is not registered by default, so we explicitly set the relevant coders.
+ checkState(input.getCoder() instanceof KvCoder,
+ "Input to a %s using state requires a %s, but the coder was %s",
+ ParDo.class.getSimpleName(),
+ KvCoder.class.getSimpleName(),
+ input.getCoder());
+ KvCoder<K, InputT> kvCoder = (KvCoder<K, InputT>) input.getCoder();
+ Coder<K> keyCoder = kvCoder.getKeyCoder();
+ Coder<? extends BoundedWindow> windowCoder =
+ input.getWindowingStrategy().getWindowFn().windowCoder();
+
+ PCollectionTuple outputs =
+ input
+ // Stash the original timestamps, etc, for when it is fed to the user's DoFn
+ .apply("Reify timestamps", ParDo.of(new ReifyWindowedValueFn<K, InputT>()))
+ .setCoder(KvCoder.of(keyCoder, WindowedValue.getFullCoder(kvCoder, windowCoder)))
+
+ // A full GBK to group by key _and_ window
+ .apply("Group by key", GroupByKey.<K, WindowedValue<KV<K, InputT>>>create())
+
+ // Adapt to KeyedWorkItem; that is how this runner delivers timers
+ .apply("To KeyedWorkItem", ParDo.of(new ToKeyedWorkItem<K, InputT>()))
+ .setCoder(KeyedWorkItemCoder.of(keyCoder, kvCoder, windowCoder))
+
+ // Explode the resulting iterable into elements that are exactly the ones from
+ // the input
+ .apply("Stateful ParDo", new StatefulParDo<>(underlyingParDo, input));
return outputs;
}
}
static class StatefulParDo<K, InputT, OutputT>
- extends PTransform<PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple> {
+ extends PTransform<PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple> {
private final transient ParDo.BoundMulti<KV<K, InputT>, OutputT> underlyingParDo;
private final transient PCollection<KV<K, InputT>> originalInput;
@@ -110,21 +143,58 @@ class ParDoMultiOverrideFactory<InputT, OutputT>
@Override
public <T> Coder<T> getDefaultOutputCoder(
- PCollection<? extends KV<K, Iterable<InputT>>> input, TypedPValue<T> output)
+ PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>> input, TypedPValue<T> output)
throws CannotProvideCoderException {
return underlyingParDo.getDefaultOutputCoder(originalInput, output);
}
- public PCollectionTuple expand(PCollection<? extends KV<K, Iterable<InputT>>> input) {
+ @Override
+ public PCollectionTuple expand(PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>> input) {
- PCollectionTuple outputs = PCollectionTuple.ofPrimitiveOutputsInternal(
- input.getPipeline(),
- TupleTagList.of(underlyingParDo.getMainOutputTag())
- .and(underlyingParDo.getSideOutputTags().getAll()),
- input.getWindowingStrategy(),
- input.isBounded());
+ PCollectionTuple outputs =
+ PCollectionTuple.ofPrimitiveOutputsInternal(
+ input.getPipeline(),
+ TupleTagList.of(underlyingParDo.getMainOutputTag())
+ .and(underlyingParDo.getSideOutputTags().getAll()),
+ input.getWindowingStrategy(),
+ input.isBounded());
return outputs;
}
}
+
+ /**
+ * A key-preserving {@link DoFn} that reifies each element's windowing metadata.
+ *
+ * <p>Each input {@link KV} is re-emitted under its original key, with the value replaced by a
+ * {@link WindowedValue} holding the whole element plus its timestamp, window, and pane, so that
+ * they survive the {@link GroupByKey} and can be restored for the user's {@link DoFn}.
+ */
+ static class ReifyWindowedValueFn<K, V> extends DoFn<KV<K, V>, KV<K, WindowedValue<KV<K, V>>>> {
+ @ProcessElement
+ public void processElement(final ProcessContext c, final BoundedWindow window) {
+ c.output(
+ KV.of(
+ c.element().getKey(),
+ WindowedValue.of(c.element(), c.timestamp(), window, c.pane())));
+ }
+ }
+
+ /**
+ * A runner-specific primitive that is just a key-preserving {@link ParDo}, but we do not have the
+ * machinery to detect or enforce that yet.
+ *
+ * <p>This wraps the {@link GroupByKey} output in a {@link KeyedWorkItem} to be able to deliver
+ * timers. The grouped {@link WindowedValue WindowedValues} become the work item's elements
+ * unchanged; they are handed to the user's {@link DoFn} one at a time when it is evaluated.
+ */
+ static class ToKeyedWorkItem<K, V>
+ extends DoFn<KV<K, Iterable<WindowedValue<KV<K, V>>>>, KeyedWorkItem<K, KV<K, V>>> {
+
+ @ProcessElement
+ public void processElement(final ProcessContext c, final BoundedWindow window) {
+ final K key = c.element().getKey();
+ c.output(KeyedWorkItems.elementsWorkItem(key, c.element().getValue()));
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1f018ab6/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
index 1f64d9a..5f9d8f4 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
@@ -23,6 +23,8 @@ import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.Lists;
import java.util.Collections;
+import org.apache.beam.runners.core.KeyedWorkItem;
+import org.apache.beam.runners.core.KeyedWorkItems;
import org.apache.beam.runners.direct.DirectExecutionContext.DirectStepContext;
import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo;
@@ -77,12 +79,12 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo
}
@SuppressWarnings({"unchecked", "rawtypes"})
- private TransformEvaluator<KV<K, Iterable<InputT>>> createEvaluator(
+ private TransformEvaluator<KeyedWorkItem<K, KV<K, InputT>>> createEvaluator(
AppliedPTransform<
- PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple,
+ PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple,
StatefulParDo<K, InputT, OutputT>>
application,
- CommittedBundle<KV<K, Iterable<InputT>>> inputBundle)
+ CommittedBundle<KeyedWorkItem<K, KV<K, InputT>>> inputBundle)
throws Exception {
final DoFn<KV<K, InputT>, OutputT> doFn =
@@ -185,7 +187,7 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo
@AutoValue
abstract static class AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT> {
abstract AppliedPTransform<
- PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple,
+ PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple,
StatefulParDo<K, InputT, OutputT>>
getTransform();
@@ -195,7 +197,7 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo
static <K, InputT, OutputT> AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT> create(
AppliedPTransform<
- PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple,
+ PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple,
StatefulParDo<K, InputT, OutputT>>
transform,
StructuralKey<K> key,
@@ -206,7 +208,7 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo
}
private static class StatefulParDoEvaluator<K, InputT>
- implements TransformEvaluator<KV<K, Iterable<InputT>>> {
+ implements TransformEvaluator<KeyedWorkItem<K, KV<K, InputT>>> {
private final TransformEvaluator<KV<K, InputT>> delegateEvaluator;
@@ -215,20 +217,20 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo
}
@Override
- public void processElement(WindowedValue<KV<K, Iterable<InputT>>> gbkResult) throws Exception {
+ public void processElement(WindowedValue<KeyedWorkItem<K, KV<K, InputT>>> gbkResult)
+ throws Exception {
- for (InputT value : gbkResult.getValue().getValue()) {
- delegateEvaluator.processElement(
- gbkResult.withValue(KV.of(gbkResult.getValue().getKey(), value)));
+ for (WindowedValue<KV<K, InputT>> windowedValue : gbkResult.getValue().elementsIterable()) {
+ delegateEvaluator.processElement(windowedValue);
}
}
@Override
- public TransformResult<KV<K, Iterable<InputT>>> finishBundle() throws Exception {
+ public TransformResult<KeyedWorkItem<K, KV<K, InputT>>> finishBundle() throws Exception {
TransformResult<KV<K, InputT>> delegateResult = delegateEvaluator.finishBundle();
- StepTransformResult.Builder<KV<K, Iterable<InputT>>> regroupedResult =
- StepTransformResult.<KV<K, Iterable<InputT>>>withHold(
+ StepTransformResult.Builder<KeyedWorkItem<K, KV<K, InputT>>> regroupedResult =
+ StepTransformResult.<KeyedWorkItem<K, KV<K, InputT>>>withHold(
delegateResult.getTransform(), delegateResult.getWatermarkHold())
.withTimerUpdate(delegateResult.getTimerUpdate())
.withAggregatorChanges(delegateResult.getAggregatorChanges())
@@ -240,12 +242,10 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo
// outputs, but just make a bunch of singletons
for (WindowedValue<?> untypedUnprocessed : delegateResult.getUnprocessedElements()) {
WindowedValue<KV<K, InputT>> windowedKv = (WindowedValue<KV<K, InputT>>) untypedUnprocessed;
- WindowedValue<KV<K, Iterable<InputT>>> pushedBack =
+ WindowedValue<KeyedWorkItem<K, KV<K, InputT>>> pushedBack =
windowedKv.withValue(
- KV.of(
- windowedKv.getValue().getKey(),
- (Iterable<InputT>)
- Collections.singletonList(windowedKv.getValue().getValue())));
+ KeyedWorkItems.elementsWorkItem(
+ windowedKv.getValue().getKey(), Collections.singleton(windowedKv)));
regroupedResult.addUnprocessedElements(pushedBack);
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1f018ab6/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java
index a357005..a1fb81b 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java
@@ -22,8 +22,10 @@ import static org.hamcrest.Matchers.not;
import static org.junit.Assert.assertThat;
import java.util.Collections;
+import org.apache.beam.runners.core.KeyedWorkItem;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.testing.TestPipeline;
@@ -32,8 +34,12 @@ import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.Keys;
import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
+import org.joda.time.Instant;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@@ -41,9 +47,7 @@ import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
-/**
- * Tests for {@link KeyedPValueTrackingVisitor}.
- */
+/** Tests for {@link KeyedPValueTrackingVisitor}. */
@RunWith(JUnit4.class)
public class KeyedPValueTrackingVisitorTest {
@Rule public ExpectedException thrown = ExpectedException.none();
@@ -61,8 +65,7 @@ public class KeyedPValueTrackingVisitorTest {
@Test
public void groupByKeyProducesKeyedOutput() {
PCollection<KV<String, Iterable<Integer>>> keyed =
- p.apply(Create.of(KV.of("foo", 3)))
- .apply(GroupByKey.<String, Integer>create());
+ p.apply(Create.of(KV.of("foo", 3))).apply(GroupByKey.<String, Integer>create());
p.traverseTopologically(visitor);
assertThat(visitor.getKeyedPValues(), hasItem(keyed));
@@ -91,16 +94,66 @@ public class KeyedPValueTrackingVisitorTest {
}
@Test
+ public void unkeyedInputWithKeyPreserving() {
+
+ PCollection<KV<String, Iterable<WindowedValue<KV<String, Integer>>>>> input =
+ p.apply(
+ Create.of(
+ KV.of(
+ "hello",
+ (Iterable<WindowedValue<KV<String, Integer>>>)
+ Collections.<WindowedValue<KV<String, Integer>>>emptyList()))
+ .withCoder(
+ KvCoder.of(
+ StringUtf8Coder.of(),
+ IterableCoder.of(
+ WindowedValue.getValueOnlyCoder(
+ KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()))))));
+
+ PCollection<KeyedWorkItem<String, KV<String, Integer>>> unkeyed =
+ input.apply(ParDo.of(new ParDoMultiOverrideFactory.ToKeyedWorkItem<String, Integer>()));
+
+ p.traverseTopologically(visitor);
+ assertThat(visitor.getKeyedPValues(), not(hasItem(unkeyed)));
+ }
+
+ @Test
+ public void keyedInputWithKeyPreserving() {
+
+ PCollection<KV<String, WindowedValue<KV<String, Integer>>>> input =
+ p.apply(
+ Create.of(
+ KV.of(
+ "hello",
+ WindowedValue.of(
+ KV.of("hello", 3),
+ new Instant(0),
+ new IntervalWindow(new Instant(0), new Instant(9)),
+ PaneInfo.NO_FIRING)))
+ .withCoder(
+ KvCoder.of(
+ StringUtf8Coder.of(),
+ WindowedValue.getValueOnlyCoder(
+ KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of())))));
+
+ PCollection<KeyedWorkItem<String, KV<String, Integer>>> keyed =
+ input
+ .apply(GroupByKey.<String, WindowedValue<KV<String, Integer>>>create())
+ .apply(ParDo.of(new ParDoMultiOverrideFactory.ToKeyedWorkItem<String, Integer>()));
+
+ p.traverseTopologically(visitor);
+ assertThat(visitor.getKeyedPValues(), hasItem(keyed));
+ }
+
+ @Test
public void traverseMultipleTimesThrows() {
p.apply(
- Create.<KV<Integer, Void>>of(
- KV.of(1, (Void) null), KV.of(2, (Void) null), KV.of(3, (Void) null))
+ Create.of(KV.of(1, (Void) null), KV.of(2, (Void) null), KV.of(3, (Void) null))
.withCoder(KvCoder.of(VarIntCoder.of(), VoidCoder.of())))
.apply(GroupByKey.<Integer, Void>create())
.apply(Keys.<Integer>create());
p.traverseTopologically(visitor);
-
thrown.expect(IllegalStateException.class);
thrown.expectMessage("already been finalized");
thrown.expectMessage(KeyedPValueTrackingVisitor.class.getSimpleName());
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1f018ab6/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
index d312aa3..b88d5e0 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
@@ -27,12 +27,14 @@ import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
-import com.google.common.collect.Lists;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
+import org.apache.beam.runners.core.KeyedWorkItem;
+import org.apache.beam.runners.core.KeyedWorkItems;
import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle;
import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo;
@@ -136,7 +138,7 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable {
new StatefulParDoEvaluatorFactory(mockEvaluationContext);
AppliedPTransform<
- PCollection<? extends KV<String, Iterable<Integer>>>, PCollectionTuple,
+ PCollection<? extends KeyedWorkItem<String, KV<String, Integer>>>, PCollectionTuple,
StatefulParDo<String, Integer, Integer>>
producingTransform = (AppliedPTransform) DirectGraphs.getProducer(produced);
@@ -245,7 +247,7 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable {
// This will be the stateful ParDo from the expansion
AppliedPTransform<
- PCollection<KV<String, Iterable<Integer>>>, PCollectionTuple,
+ PCollection<KeyedWorkItem<String, KV<String, Integer>>>, PCollectionTuple,
StatefulParDo<String, Integer, Integer>>
producingTransform = (AppliedPTransform) DirectGraphs.getProducer(produced);
@@ -270,37 +272,50 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable {
// A single bundle with some elements in the global window; it should register cleanup for the
// global window state merely by having the evaluator created. The cleanup logic does not
// depend on the window.
- WindowedValue<KV<String, Iterable<Integer>>> gbkOutputElement =
- WindowedValue.of(
- KV.<String, Iterable<Integer>>of("hello", Lists.newArrayList(1, 13, 15)),
- new Instant(3),
- firstWindow,
- PaneInfo.NO_FIRING);
- CommittedBundle<KV<String, Iterable<Integer>>> inputBundle =
+ String key = "hello";
+ WindowedValue<KV<String, Integer>> firstKv = WindowedValue.of(
+ KV.of(key, 1),
+ new Instant(3),
+ firstWindow,
+ PaneInfo.NO_FIRING);
+
+ WindowedValue<KeyedWorkItem<String, KV<String, Integer>>> gbkOutputElement =
+ firstKv.withValue(
+ KeyedWorkItems.elementsWorkItem(
+ "hello",
+ ImmutableList.of(
+ firstKv,
+ firstKv.withValue(KV.of(key, 13)),
+ firstKv.withValue(KV.of(key, 15)))));
+
+ CommittedBundle<KeyedWorkItem<String, KV<String, Integer>>> inputBundle =
BUNDLE_FACTORY
.createBundle(producingTransform.getInput())
.add(gbkOutputElement)
.commit(Instant.now());
- TransformEvaluator<KV<String, Iterable<Integer>>> evaluator =
+ TransformEvaluator<KeyedWorkItem<String, KV<String, Integer>>> evaluator =
factory.forApplication(producingTransform, inputBundle);
+
evaluator.processElement(gbkOutputElement);
// This should push back every element as a KV<String, Iterable<Integer>>
// in the appropriate window. Since the keys are equal they are single-threaded
- TransformResult<KV<String, Iterable<Integer>>> result = evaluator.finishBundle();
+ TransformResult<KeyedWorkItem<String, KV<String, Integer>>> result =
+ evaluator.finishBundle();
List<Integer> pushedBackInts = new ArrayList<>();
- for (WindowedValue<?> unprocessedElement : result.getUnprocessedElements()) {
- WindowedValue<KV<String, Iterable<Integer>>> unprocessedKv =
- (WindowedValue<KV<String, Iterable<Integer>>>) unprocessedElement;
+ for (WindowedValue<? extends KeyedWorkItem<String, KV<String, Integer>>> unprocessedElement :
+ result.getUnprocessedElements()) {
assertThat(
Iterables.getOnlyElement(unprocessedElement.getWindows()),
equalTo((BoundedWindow) firstWindow));
- assertThat(unprocessedKv.getValue().getKey(), equalTo("hello"));
- for (Integer i : unprocessedKv.getValue().getValue()) {
- pushedBackInts.add(i);
+
+ assertThat(unprocessedElement.getValue().key(), equalTo("hello"));
+ for (WindowedValue<KV<String, Integer>> windowedKv :
+ unprocessedElement.getValue().elementsIterable()) {
+ pushedBackInts.add(windowedKv.getValue().getValue());
}
}
assertThat(pushedBackInts, containsInAnyOrder(1, 13, 15));