You are viewing a plain-text version of this content. (The link to the canonical version was lost in the plain-text conversion.)
Posted to commits@beam.apache.org by ke...@apache.org on 2016/12/21 22:49:49 UTC
[15/51] [abbrv] incubator-beam git commit: Add initial key-preserving transform support to
KeyedPValueTrackingVisitor
Add initial key-preserving transform support to KeyedPValueTrackingVisitor
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/81702e67
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/81702e67
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/81702e67
Branch: refs/heads/python-sdk
Commit: 81702e67b92a23849cbc8f4a16b2a619e4b477a1
Parents: 22e25a4
Author: Kenneth Knowles <kl...@google.com>
Authored: Thu Dec 8 11:49:15 2016 -0800
Committer: Kenneth Knowles <kl...@google.com>
Committed: Tue Dec 20 11:18:02 2016 -0800
----------------------------------------------------------------------
.../beam/runners/direct/DirectRunner.java | 9 +--
.../direct/KeyedPValueTrackingVisitor.java | 35 +++++---
.../direct/KeyedPValueTrackingVisitorTest.java | 84 +++-----------------
3 files changed, 37 insertions(+), 91 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/81702e67/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
index 78163c0..afa43ff 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
@@ -31,8 +31,6 @@ import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.beam.runners.core.SplittableParDo;
-import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupAlsoByWindow;
-import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupByKeyOnly;
import org.apache.beam.runners.direct.DirectRunner.DirectPipelineResult;
import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory;
import org.apache.beam.runners.direct.ViewEvaluatorFactory.ViewOverrideFactory;
@@ -306,12 +304,7 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> {
graphVisitor.finishSpecifyingRemainder();
@SuppressWarnings("rawtypes")
- KeyedPValueTrackingVisitor keyedPValueVisitor =
- KeyedPValueTrackingVisitor.create(
- ImmutableSet.of(
- SplittableParDo.GBKIntoKeyedWorkItems.class,
- DirectGroupByKeyOnly.class,
- DirectGroupAlsoByWindow.class));
+ KeyedPValueTrackingVisitor keyedPValueVisitor = KeyedPValueTrackingVisitor.create();
pipeline.traverseTopologically(keyedPValueVisitor);
DisplayDataValidator.validatePipeline(pipeline);
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/81702e67/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
index 7f85169..e91a768 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
@@ -18,9 +18,15 @@
package org.apache.beam.runners.direct;
import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.base.Predicates.in;
+import static com.google.common.collect.Iterables.all;
+import com.google.common.collect.ImmutableSet;
import java.util.HashSet;
import java.util.Set;
+import org.apache.beam.runners.core.SplittableParDo;
+import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupAlsoByWindow;
+import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupByKeyOnly;
import org.apache.beam.sdk.Pipeline.PipelineVisitor;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.GroupByKey;
@@ -38,19 +44,21 @@ import org.apache.beam.sdk.values.PValue;
// TODO: Handle Key-preserving transforms when appropriate and more aggressively make PTransforms
// unkeyed
class KeyedPValueTrackingVisitor implements PipelineVisitor {
- @SuppressWarnings("rawtypes")
- private final Set<Class<? extends PTransform>> producesKeyedOutputs;
+
+ private static final Set<Class<? extends PTransform>> PRODUCES_KEYED_OUTPUTS =
+ ImmutableSet.of(
+ SplittableParDo.GBKIntoKeyedWorkItems.class,
+ DirectGroupByKeyOnly.class,
+ DirectGroupAlsoByWindow.class);
+
private final Set<PValue> keyedValues;
private boolean finalized;
- public static KeyedPValueTrackingVisitor create(
- @SuppressWarnings("rawtypes") Set<Class<? extends PTransform>> producesKeyedOutputs) {
- return new KeyedPValueTrackingVisitor(producesKeyedOutputs);
+ public static KeyedPValueTrackingVisitor create() {
+ return new KeyedPValueTrackingVisitor();
}
- private KeyedPValueTrackingVisitor(
- @SuppressWarnings("rawtypes") Set<Class<? extends PTransform>> producesKeyedOutputs) {
- this.producesKeyedOutputs = producesKeyedOutputs;
+ private KeyedPValueTrackingVisitor() {
this.keyedValues = new HashSet<>();
}
@@ -73,7 +81,7 @@ class KeyedPValueTrackingVisitor implements PipelineVisitor {
node);
if (node.isRootNode()) {
finalized = true;
- } else if (producesKeyedOutputs.contains(node.getTransform().getClass())) {
+ } else if (PRODUCES_KEYED_OUTPUTS.contains(node.getTransform().getClass())) {
keyedValues.addAll(node.getOutputs());
}
}
@@ -83,7 +91,9 @@ class KeyedPValueTrackingVisitor implements PipelineVisitor {
@Override
public void visitValue(PValue value, TransformHierarchy.Node producer) {
- if (producesKeyedOutputs.contains(producer.getTransform().getClass())) {
+ if (PRODUCES_KEYED_OUTPUTS.contains(producer.getTransform().getClass())
+ || (isKeyPreserving(producer.getTransform())
+ && all(producer.getInputs(), in(keyedValues)))) {
keyedValues.add(value);
}
}
@@ -93,4 +103,9 @@ class KeyedPValueTrackingVisitor implements PipelineVisitor {
finalized, "can't call getKeyedPValues before a Pipeline has been completely traversed");
return keyedValues;
}
+
+ private static boolean isKeyPreserving(PTransform<?, ?> transform) {
+ // There are currently no key-preserving transforms; this lays the infrastructure for them
+ return false;
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/81702e67/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java
index eef3375..a357005 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java
@@ -21,9 +21,7 @@ import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.not;
import static org.junit.Assert.assertThat;
-import com.google.common.collect.ImmutableSet;
import java.util.Collections;
-import java.util.Set;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.VarIntCoder;
@@ -33,7 +31,6 @@ import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.Keys;
-import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
@@ -57,54 +54,20 @@ public class KeyedPValueTrackingVisitorTest {
@Before
public void setup() {
-
- @SuppressWarnings("rawtypes")
- Set<Class<? extends PTransform>> producesKeyed =
- ImmutableSet.<Class<? extends PTransform>>of(PrimitiveKeyer.class, CompositeKeyer.class);
- visitor = KeyedPValueTrackingVisitor.create(producesKeyed);
- }
-
- @Test
- public void primitiveProducesKeyedOutputUnkeyedInputKeyedOutput() {
- PCollection<Integer> keyed =
- p.apply(Create.<Integer>of(1, 2, 3)).apply(new PrimitiveKeyer<Integer>());
-
- p.traverseTopologically(visitor);
- assertThat(visitor.getKeyedPValues(), hasItem(keyed));
- }
-
- @Test
- public void primitiveProducesKeyedOutputKeyedInputKeyedOutut() {
- PCollection<Integer> keyed =
- p.apply(Create.<Integer>of(1, 2, 3))
- .apply("firstKey", new PrimitiveKeyer<Integer>())
- .apply("secondKey", new PrimitiveKeyer<Integer>());
-
- p.traverseTopologically(visitor);
- assertThat(visitor.getKeyedPValues(), hasItem(keyed));
- }
-
- @Test
- public void compositeProducesKeyedOutputUnkeyedInputKeyedOutput() {
- PCollection<Integer> keyed =
- p.apply(Create.<Integer>of(1, 2, 3)).apply(new CompositeKeyer<Integer>());
-
- p.traverseTopologically(visitor);
- assertThat(visitor.getKeyedPValues(), hasItem(keyed));
+ p = TestPipeline.create();
+ visitor = KeyedPValueTrackingVisitor.create();
}
@Test
- public void compositeProducesKeyedOutputKeyedInputKeyedOutut() {
- PCollection<Integer> keyed =
- p.apply(Create.<Integer>of(1, 2, 3))
- .apply("firstKey", new CompositeKeyer<Integer>())
- .apply("secondKey", new CompositeKeyer<Integer>());
+ public void groupByKeyProducesKeyedOutput() {
+ PCollection<KV<String, Iterable<Integer>>> keyed =
+ p.apply(Create.of(KV.of("foo", 3)))
+ .apply(GroupByKey.<String, Integer>create());
p.traverseTopologically(visitor);
assertThat(visitor.getKeyedPValues(), hasItem(keyed));
}
-
@Test
public void noInputUnkeyedOutput() {
PCollection<KV<Integer, Iterable<Void>>> unkeyed =
@@ -117,26 +80,17 @@ public class KeyedPValueTrackingVisitorTest {
}
@Test
- public void keyedInputNotProducesKeyedOutputUnkeyedOutput() {
- PCollection<Integer> onceKeyed =
- p.apply(Create.<Integer>of(1, 2, 3))
- .apply(new PrimitiveKeyer<Integer>())
- .apply(ParDo.of(new IdentityFn<Integer>()));
+ public void keyedInputWithoutKeyPreserving() {
+ PCollection<KV<String, Iterable<Integer>>> onceKeyed =
+ p.apply(Create.of(KV.of("hello", 42)))
+ .apply(GroupByKey.<String, Integer>create())
+ .apply(ParDo.of(new IdentityFn<KV<String, Iterable<Integer>>>()));
p.traverseTopologically(visitor);
assertThat(visitor.getKeyedPValues(), not(hasItem(onceKeyed)));
}
@Test
- public void unkeyedInputNotProducesKeyedOutputUnkeyedOutput() {
- PCollection<Integer> unkeyed =
- p.apply(Create.<Integer>of(1, 2, 3)).apply(ParDo.of(new IdentityFn<Integer>()));
-
- p.traverseTopologically(visitor);
- assertThat(visitor.getKeyedPValues(), not(hasItem(unkeyed)));
- }
-
- @Test
public void traverseMultipleTimesThrows() {
p.apply(
Create.<KV<Integer, Void>>of(
@@ -161,22 +115,6 @@ public class KeyedPValueTrackingVisitorTest {
visitor.getKeyedPValues();
}
- private static class PrimitiveKeyer<K> extends PTransform<PCollection<K>, PCollection<K>> {
- @Override
- public PCollection<K> expand(PCollection<K> input) {
- return PCollection.<K>createPrimitiveOutputInternal(
- input.getPipeline(), input.getWindowingStrategy(), input.isBounded())
- .setCoder(input.getCoder());
- }
- }
-
- private static class CompositeKeyer<K> extends PTransform<PCollection<K>, PCollection<K>> {
- @Override
- public PCollection<K> expand(PCollection<K> input) {
- return input.apply(new PrimitiveKeyer<K>()).apply(ParDo.of(new IdentityFn<K>()));
- }
- }
-
private static class IdentityFn<K> extends DoFn<K, K> {
@ProcessElement
public void processElement(ProcessContext c) throws Exception {