You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ke...@apache.org on 2016/12/21 22:49:49 UTC

[15/51] [abbrv] incubator-beam git commit: Add some key-preserving to KeyedPValueTrackingVisitor

Add some key-preserving to KeyedPValueTrackingVisitor


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/81702e67
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/81702e67
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/81702e67

Branch: refs/heads/python-sdk
Commit: 81702e67b92a23849cbc8f4a16b2a619e4b477a1
Parents: 22e25a4
Author: Kenneth Knowles <kl...@google.com>
Authored: Thu Dec 8 11:49:15 2016 -0800
Committer: Kenneth Knowles <kl...@google.com>
Committed: Tue Dec 20 11:18:02 2016 -0800

----------------------------------------------------------------------
 .../beam/runners/direct/DirectRunner.java       |  9 +--
 .../direct/KeyedPValueTrackingVisitor.java      | 35 +++++---
 .../direct/KeyedPValueTrackingVisitorTest.java  | 84 +++-----------------
 3 files changed, 37 insertions(+), 91 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/81702e67/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
index 78163c0..afa43ff 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
@@ -31,8 +31,6 @@ import java.util.Map;
 import java.util.Set;
 import javax.annotation.Nullable;
 import org.apache.beam.runners.core.SplittableParDo;
-import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupAlsoByWindow;
-import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupByKeyOnly;
 import org.apache.beam.runners.direct.DirectRunner.DirectPipelineResult;
 import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory;
 import org.apache.beam.runners.direct.ViewEvaluatorFactory.ViewOverrideFactory;
@@ -306,12 +304,7 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> {
     graphVisitor.finishSpecifyingRemainder();
 
     @SuppressWarnings("rawtypes")
-    KeyedPValueTrackingVisitor keyedPValueVisitor =
-        KeyedPValueTrackingVisitor.create(
-            ImmutableSet.of(
-                SplittableParDo.GBKIntoKeyedWorkItems.class,
-                DirectGroupByKeyOnly.class,
-                DirectGroupAlsoByWindow.class));
+    KeyedPValueTrackingVisitor keyedPValueVisitor = KeyedPValueTrackingVisitor.create();
     pipeline.traverseTopologically(keyedPValueVisitor);
 
     DisplayDataValidator.validatePipeline(pipeline);

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/81702e67/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
index 7f85169..e91a768 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
@@ -18,9 +18,15 @@
 package org.apache.beam.runners.direct;
 
 import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.base.Predicates.in;
+import static com.google.common.collect.Iterables.all;
 
+import com.google.common.collect.ImmutableSet;
 import java.util.HashSet;
 import java.util.Set;
+import org.apache.beam.runners.core.SplittableParDo;
+import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupAlsoByWindow;
+import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupByKeyOnly;
 import org.apache.beam.sdk.Pipeline.PipelineVisitor;
 import org.apache.beam.sdk.runners.TransformHierarchy;
 import org.apache.beam.sdk.transforms.GroupByKey;
@@ -38,19 +44,21 @@ import org.apache.beam.sdk.values.PValue;
 // TODO: Handle Key-preserving transforms when appropriate and more aggressively make PTransforms
 // unkeyed
 class KeyedPValueTrackingVisitor implements PipelineVisitor {
-  @SuppressWarnings("rawtypes")
-  private final Set<Class<? extends PTransform>> producesKeyedOutputs;
+
+  private static final Set<Class<? extends PTransform>> PRODUCES_KEYED_OUTPUTS =
+      ImmutableSet.of(
+          SplittableParDo.GBKIntoKeyedWorkItems.class,
+          DirectGroupByKeyOnly.class,
+          DirectGroupAlsoByWindow.class);
+
   private final Set<PValue> keyedValues;
   private boolean finalized;
 
-  public static KeyedPValueTrackingVisitor create(
-      @SuppressWarnings("rawtypes") Set<Class<? extends PTransform>> producesKeyedOutputs) {
-    return new KeyedPValueTrackingVisitor(producesKeyedOutputs);
+  public static KeyedPValueTrackingVisitor create() {
+    return new KeyedPValueTrackingVisitor();
   }
 
-  private KeyedPValueTrackingVisitor(
-      @SuppressWarnings("rawtypes") Set<Class<? extends PTransform>> producesKeyedOutputs) {
-    this.producesKeyedOutputs = producesKeyedOutputs;
+  private KeyedPValueTrackingVisitor() {
     this.keyedValues = new HashSet<>();
   }
 
@@ -73,7 +81,7 @@ class KeyedPValueTrackingVisitor implements PipelineVisitor {
         node);
     if (node.isRootNode()) {
       finalized = true;
-    } else if (producesKeyedOutputs.contains(node.getTransform().getClass())) {
+    } else if (PRODUCES_KEYED_OUTPUTS.contains(node.getTransform().getClass())) {
       keyedValues.addAll(node.getOutputs());
     }
   }
@@ -83,7 +91,9 @@ class KeyedPValueTrackingVisitor implements PipelineVisitor {
 
   @Override
   public void visitValue(PValue value, TransformHierarchy.Node producer) {
-    if (producesKeyedOutputs.contains(producer.getTransform().getClass())) {
+    if (PRODUCES_KEYED_OUTPUTS.contains(producer.getTransform().getClass())
+        || (isKeyPreserving(producer.getTransform())
+            && all(producer.getInputs(), in(keyedValues)))) {
       keyedValues.add(value);
     }
   }
@@ -93,4 +103,9 @@ class KeyedPValueTrackingVisitor implements PipelineVisitor {
         finalized, "can't call getKeyedPValues before a Pipeline has been completely traversed");
     return keyedValues;
   }
+
+  private static boolean isKeyPreserving(PTransform<?, ?> transform) {
+    // There are currently no key-preserving transforms; this lays the infrastructure for them
+    return false;
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/81702e67/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java
index eef3375..a357005 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java
@@ -21,9 +21,7 @@ import static org.hamcrest.Matchers.hasItem;
 import static org.hamcrest.Matchers.not;
 import static org.junit.Assert.assertThat;
 
-import com.google.common.collect.ImmutableSet;
 import java.util.Collections;
-import java.util.Set;
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.VarIntCoder;
@@ -33,7 +31,6 @@ import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.Keys;
-import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
@@ -57,54 +54,20 @@ public class KeyedPValueTrackingVisitorTest {
 
   @Before
   public void setup() {
-
-    @SuppressWarnings("rawtypes")
-    Set<Class<? extends PTransform>> producesKeyed =
-        ImmutableSet.<Class<? extends PTransform>>of(PrimitiveKeyer.class, CompositeKeyer.class);
-    visitor = KeyedPValueTrackingVisitor.create(producesKeyed);
-  }
-
-  @Test
-  public void primitiveProducesKeyedOutputUnkeyedInputKeyedOutput() {
-    PCollection<Integer> keyed =
-        p.apply(Create.<Integer>of(1, 2, 3)).apply(new PrimitiveKeyer<Integer>());
-
-    p.traverseTopologically(visitor);
-    assertThat(visitor.getKeyedPValues(), hasItem(keyed));
-  }
-
-  @Test
-  public void primitiveProducesKeyedOutputKeyedInputKeyedOutut() {
-    PCollection<Integer> keyed =
-        p.apply(Create.<Integer>of(1, 2, 3))
-            .apply("firstKey", new PrimitiveKeyer<Integer>())
-            .apply("secondKey", new PrimitiveKeyer<Integer>());
-
-    p.traverseTopologically(visitor);
-    assertThat(visitor.getKeyedPValues(), hasItem(keyed));
-  }
-
-  @Test
-  public void compositeProducesKeyedOutputUnkeyedInputKeyedOutput() {
-    PCollection<Integer> keyed =
-        p.apply(Create.<Integer>of(1, 2, 3)).apply(new CompositeKeyer<Integer>());
-
-    p.traverseTopologically(visitor);
-    assertThat(visitor.getKeyedPValues(), hasItem(keyed));
+    p = TestPipeline.create();
+    visitor = KeyedPValueTrackingVisitor.create();
   }
 
   @Test
-  public void compositeProducesKeyedOutputKeyedInputKeyedOutut() {
-    PCollection<Integer> keyed =
-        p.apply(Create.<Integer>of(1, 2, 3))
-            .apply("firstKey", new CompositeKeyer<Integer>())
-            .apply("secondKey", new CompositeKeyer<Integer>());
+  public void groupByKeyProducesKeyedOutput() {
+    PCollection<KV<String, Iterable<Integer>>> keyed =
+        p.apply(Create.of(KV.of("foo", 3)))
+            .apply(GroupByKey.<String, Integer>create());
 
     p.traverseTopologically(visitor);
     assertThat(visitor.getKeyedPValues(), hasItem(keyed));
   }
 
-
   @Test
   public void noInputUnkeyedOutput() {
     PCollection<KV<Integer, Iterable<Void>>> unkeyed =
@@ -117,26 +80,17 @@ public class KeyedPValueTrackingVisitorTest {
   }
 
   @Test
-  public void keyedInputNotProducesKeyedOutputUnkeyedOutput() {
-    PCollection<Integer> onceKeyed =
-        p.apply(Create.<Integer>of(1, 2, 3))
-            .apply(new PrimitiveKeyer<Integer>())
-            .apply(ParDo.of(new IdentityFn<Integer>()));
+  public void keyedInputWithoutKeyPreserving() {
+    PCollection<KV<String, Iterable<Integer>>> onceKeyed =
+        p.apply(Create.of(KV.of("hello", 42)))
+            .apply(GroupByKey.<String, Integer>create())
+            .apply(ParDo.of(new IdentityFn<KV<String, Iterable<Integer>>>()));
 
     p.traverseTopologically(visitor);
     assertThat(visitor.getKeyedPValues(), not(hasItem(onceKeyed)));
   }
 
   @Test
-  public void unkeyedInputNotProducesKeyedOutputUnkeyedOutput() {
-    PCollection<Integer> unkeyed =
-        p.apply(Create.<Integer>of(1, 2, 3)).apply(ParDo.of(new IdentityFn<Integer>()));
-
-    p.traverseTopologically(visitor);
-    assertThat(visitor.getKeyedPValues(), not(hasItem(unkeyed)));
-  }
-
-  @Test
   public void traverseMultipleTimesThrows() {
     p.apply(
             Create.<KV<Integer, Void>>of(
@@ -161,22 +115,6 @@ public class KeyedPValueTrackingVisitorTest {
     visitor.getKeyedPValues();
   }
 
-  private static class PrimitiveKeyer<K> extends PTransform<PCollection<K>, PCollection<K>> {
-    @Override
-    public PCollection<K> expand(PCollection<K> input) {
-      return PCollection.<K>createPrimitiveOutputInternal(
-              input.getPipeline(), input.getWindowingStrategy(), input.isBounded())
-          .setCoder(input.getCoder());
-    }
-  }
-
-  private static class CompositeKeyer<K> extends PTransform<PCollection<K>, PCollection<K>> {
-    @Override
-    public PCollection<K> expand(PCollection<K> input) {
-      return input.apply(new PrimitiveKeyer<K>()).apply(ParDo.of(new IdentityFn<K>()));
-    }
-  }
-
   private static class IdentityFn<K> extends DoFn<K, K> {
     @ProcessElement
     public void processElement(ProcessContext c) throws Exception {