You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by tg...@apache.org on 2017/07/21 17:03:43 UTC

[2/2] beam git commit: Register a PTransformTranslator for Combine

Register a PTransformTranslator for Combine

Include the Combine Payload in the Runner API Graph.

Add getCombineFn(AppliedPTransform) to extract the CombineFn from an
arbitrary transform.

Update Pipeline Translation tests to include accumulator coders.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/c2110c97
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/c2110c97
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/c2110c97

Branch: refs/heads/master
Commit: c2110c97d530b8e90387bf99b3ac7d36201d85c7
Parents: 1d9160f
Author: Thomas Groh <tg...@google.com>
Authored: Wed Jul 19 10:55:33 2017 -0700
Committer: Thomas Groh <tg...@google.com>
Committed: Fri Jul 21 10:03:30 2017 -0700

----------------------------------------------------------------------
 .../core/construction/CombineTranslation.java   |  83 ++++++++-
 .../construction/PTransformTranslation.java     |   3 +
 .../construction/CombineTranslationTest.java    | 171 +++++++++++++++----
 .../core/construction/SdkComponentsTest.java    |  14 +-
 4 files changed, 229 insertions(+), 42 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/c2110c97/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java
index 472b6f8..2e5b02c 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java
@@ -19,19 +19,24 @@
 package org.apache.beam.runners.core.construction;
 
 import static com.google.common.base.Preconditions.checkArgument;
+import static org.apache.beam.runners.core.construction.PTransformTranslation.COMBINE_TRANSFORM_URN;
 
+import com.google.auto.service.AutoService;
 import com.google.common.collect.Iterables;
 import com.google.protobuf.Any;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.BytesValue;
 import java.io.IOException;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator;
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.common.runner.v1.RunnerApi;
 import org.apache.beam.sdk.common.runner.v1.RunnerApi.CombinePayload;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi.Components;
 import org.apache.beam.sdk.common.runner.v1.RunnerApi.FunctionSpec;
 import org.apache.beam.sdk.common.runner.v1.RunnerApi.SdkFunctionSpec;
 import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput;
@@ -49,7 +54,47 @@ import org.apache.beam.sdk.values.PCollection;
  * RunnerApi.CombinePayload} protos.
  */
 public class CombineTranslation {
-  public static final String JAVA_SERIALIZED_COMBINE_FN_URN = "urn:beam:java:combinefn:v1";
+  public static final String JAVA_SERIALIZED_COMBINE_FN_URN = "urn:beam:combinefn:javasdk:v1";
+
+  /**
+   * Translates {@link Combine.PerKey} into a {@link FunctionSpec} with a {@link CombinePayload}.
+   */
+  public static class CombinePayloadTranslator
+      implements PTransformTranslation.TransformPayloadTranslator<Combine.PerKey<?, ?, ?>> {
+    public static TransformPayloadTranslator create() {
+      return new CombinePayloadTranslator();
+    }
+
+    private CombinePayloadTranslator() {}
+
+    @Override
+    public String getUrn(Combine.PerKey<?, ?, ?> transform) {
+      return COMBINE_TRANSFORM_URN;
+    }
+
+    @Override
+    public FunctionSpec translate(
+        AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>> transform, SdkComponents components)
+        throws IOException {
+      CombinePayload payload = toProto(transform, components);
+      return RunnerApi.FunctionSpec.newBuilder()
+          .setUrn(COMBINE_TRANSFORM_URN)
+          .setParameter(Any.pack(payload))
+          .build();
+    }
+
+    /**
+     * Registers {@link CombinePayloadTranslator} as the translator for {@link Combine.PerKey}.
+     */
+    @AutoService(TransformPayloadTranslatorRegistrar.class)
+    public static class Registrar implements TransformPayloadTranslatorRegistrar {
+      @Override
+      public Map<? extends Class<? extends PTransform>, ? extends TransformPayloadTranslator>
+          getTransformPayloadTranslators() {
+        return Collections.singletonMap(Combine.PerKey.class, new CombinePayloadTranslator());
+      }
+    }
+  }
 
   public static CombinePayload toProto(
       AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>> combine, SdkComponents sdkComponents)
@@ -72,10 +117,11 @@ public class CombineTranslation {
       GlobalCombineFn<InputT, AccumT, ?> combineFn,
       AppliedPTransform<PCollection<KV<K, InputT>>, ?, Combine.PerKey<K, InputT, ?>> transform)
       throws CannotProvideCoderException {
-    KvCoder<K, InputT> inputCoder =
-        (KvCoder<K, InputT>)
-            ((PCollection<KV<K, InputT>>) Iterables.getOnlyElement(transform.getInputs().values()))
-                .getCoder();
+    @SuppressWarnings("unchecked")
+    PCollection<KV<K, InputT>> mainInput =
+        (PCollection<KV<K, InputT>>)
+            Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(transform));
+    KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) mainInput.getCoder();
     return AppliedCombineFn.withInputCoder(
             combineFn,
             transform.getPipeline().getCoderRegistry(),
@@ -108,6 +154,14 @@ public class CombineTranslation {
     return CoderTranslation.fromProto(components.getCodersOrThrow(id), components);
   }
 
+  public static Coder<?> getAccumulatorCoder(
+      AppliedPTransform<?, ?, ?> transform) throws IOException {
+    SdkComponents sdkComponents = SdkComponents.create();
+    String id = getCombinePayload(transform, sdkComponents).getAccumulatorCoderId();
+    Components components = sdkComponents.toComponents();
+    return CoderTranslation.fromProto(components.getCodersOrThrow(id), components);
+  }
+
   public static GlobalCombineFn<?, ?, ?> getCombineFn(CombinePayload payload)
       throws IOException {
     checkArgument(payload.getCombineFn().getSpec().getUrn().equals(JAVA_SERIALIZED_COMBINE_FN_URN));
@@ -122,4 +176,23 @@ public class CombineTranslation {
                 .toByteArray(),
             "CombineFn");
   }
+
+  public static GlobalCombineFn<?, ?, ?> getCombineFn(AppliedPTransform<?, ?, ?> transform)
+      throws IOException {
+    return getCombineFn(getCombinePayload(transform));
+  }
+
+  private static CombinePayload getCombinePayload(AppliedPTransform<?, ?, ?> transform)
+      throws IOException {
+    return getCombinePayload(transform, SdkComponents.create());
+  }
+
+  private static CombinePayload getCombinePayload(
+      AppliedPTransform<?, ?, ?> transform, SdkComponents components) throws IOException {
+    return PTransformTranslation.toProto(
+            transform, Collections.<AppliedPTransform<?, ?, ?>>emptyList(), components)
+        .getSpec()
+        .getParameter()
+        .unpack(CombinePayload.class);
+  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/c2110c97/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
index 0b4a2ab..3b94724 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
@@ -52,6 +52,9 @@ public class PTransformTranslation {
   public static final String WINDOW_TRANSFORM_URN = "urn:beam:transform:window:v1";
   public static final String TEST_STREAM_TRANSFORM_URN = "urn:beam:transform:teststream:v1";
 
+  // Not strictly a primitive transform
+  public static final String COMBINE_TRANSFORM_URN = "urn:beam:transform:combine:v1";
+
   // Less well-known. And where shall these live?
   public static final String WRITE_FILES_TRANSFORM_URN = "urn:beam:transform:write_files:0.1";
 

http://git-wip-us.apache.org/repos/asf/beam/blob/c2110c97/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java
index 6251545..b3b42ab 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java
@@ -35,13 +35,19 @@ import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Combine.BinaryCombineIntegerFn;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
+import org.apache.beam.sdk.transforms.CombineWithContext.Context;
 import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
 import org.junit.Rule;
 import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
 import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
 import org.junit.runners.Parameterized;
 import org.junit.runners.Parameterized.Parameter;
 import org.junit.runners.Parameterized.Parameters;
@@ -49,46 +55,100 @@ import org.junit.runners.Parameterized.Parameters;
 /**
  * Tests for {@link CombineTranslation}.
  */
-@RunWith(Parameterized.class)
+@RunWith(Enclosed.class)
 public class CombineTranslationTest {
-  @Parameters(name = "{index}: {0}")
-  public static Iterable<Combine.CombineFn<Integer, ?, ?>> params() {
-    BinaryCombineIntegerFn sum = Sum.ofIntegers();
-    CombineFn<Integer, ?, Long> count = Count.combineFn();
-    TestCombineFn test = new TestCombineFn();
-    return ImmutableList.<CombineFn<Integer, ?, ?>>builder().add(sum).add(count).add(test).build();
+
+  /**
+   * Tests that simple {@link CombineFn CombineFns} can be translated to and from proto.
+   */
+  @RunWith(Parameterized.class)
+  public static class TranslateSimpleCombinesTest {
+    @Parameters(name = "{index}: {0}")
+    public static Iterable<Combine.CombineFn<Integer, ?, ?>> params() {
+      BinaryCombineIntegerFn sum = Sum.ofIntegers();
+      CombineFn<Integer, ?, Long> count = Count.combineFn();
+      TestCombineFn test = new TestCombineFn();
+      return ImmutableList.<CombineFn<Integer, ?, ?>>builder()
+          .add(sum)
+          .add(count)
+          .add(test)
+          .build();
+    }
+
+    @Rule public TestPipeline pipeline = TestPipeline.create();
+
+    @Parameter(0)
+    public Combine.CombineFn<Integer, ?, ?> combineFn;
+
+    @Test
+    public void testToFromProto() throws Exception {
+      PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3));
+      input.apply(Combine.globally(combineFn));
+      final AtomicReference<AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>>> combine =
+          new AtomicReference<>();
+      pipeline.traverseTopologically(
+          new PipelineVisitor.Defaults() {
+            @Override
+            public void leaveCompositeTransform(Node node) {
+              if (node.getTransform() instanceof Combine.PerKey) {
+                checkState(combine.get() == null);
+                combine.set((AppliedPTransform) node.toAppliedPTransform(getPipeline()));
+              }
+            }
+          });
+      checkState(combine.get() != null);
+      assertEquals(combineFn, CombineTranslation.getCombineFn(combine.get()));
+
+      SdkComponents sdkComponents = SdkComponents.create();
+      CombinePayload combineProto = CombineTranslation.toProto(combine.get(), sdkComponents);
+      RunnerApi.Components componentsProto = sdkComponents.toComponents();
+
+      assertEquals(combineFn.getAccumulatorCoder(pipeline.getCoderRegistry(), input.getCoder()),
+          CombineTranslation.getAccumulatorCoder(combineProto, componentsProto));
+      assertEquals(combineFn, CombineTranslation.getCombineFn(combineProto));
+    }
   }
 
-  @Rule public TestPipeline pipeline = TestPipeline.create();
-  @Parameter(0)
-  public Combine.CombineFn<Integer, ?, ?> combineFn;
-
-  @Test
-  public void testToFromProto() throws Exception {
-    PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3));
-    input.apply(Combine.globally(combineFn));
-    final AtomicReference<AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>>> combine =
-        new AtomicReference<>();
-    pipeline.traverseTopologically(
-        new PipelineVisitor.Defaults() {
-          @Override
-          public void leaveCompositeTransform(Node node) {
-            if (node.getTransform() instanceof Combine.PerKey) {
-              checkState(combine.get() == null);
-              combine.set((AppliedPTransform) node.toAppliedPTransform(getPipeline()));
+
+  /**
+   * Tests that a {@link CombineFnWithContext} can be translated.
+   */
+  @RunWith(JUnit4.class)
+  public static class ValidateCombineWithContextTest {
+    @Rule
+    public TestPipeline pipeline = TestPipeline.create();
+
+    @Test
+    public void testToFromProtoWithSideInputs() throws Exception {
+      PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3));
+      final PCollectionView<Iterable<String>> sideInput =
+          pipeline.apply(Create.of("foo")).apply(View.<String>asIterable());
+      CombineFnWithContext<Integer, int[], Integer> combineFn = new TestCombineFnWithContext();
+      input.apply(Combine.globally(combineFn).withSideInputs(sideInput).withoutDefaults());
+      final AtomicReference<AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>>> combine =
+          new AtomicReference<>();
+      pipeline.traverseTopologically(
+          new PipelineVisitor.Defaults() {
+            @Override
+            public void leaveCompositeTransform(Node node) {
+              if (node.getTransform() instanceof Combine.PerKey) {
+                checkState(combine.get() == null);
+                combine.set((AppliedPTransform) node.toAppliedPTransform(getPipeline()));
+              }
             }
-          }
-        });
-    checkState(combine.get() != null);
-
-    SdkComponents sdkComponents = SdkComponents.create();
-    CombinePayload combineProto = CombineTranslation.toProto(combine.get(), sdkComponents);
-    RunnerApi.Components componentsProto = sdkComponents.toComponents();
-
-    assertEquals(
-        combineFn.getAccumulatorCoder(pipeline.getCoderRegistry(), input.getCoder()),
-        CombineTranslation.getAccumulatorCoder(combineProto, componentsProto));
-    assertEquals(combineFn, CombineTranslation.getCombineFn(combineProto));
+          });
+      checkState(combine.get() != null);
+      assertEquals(combineFn, CombineTranslation.getCombineFn(combine.get()));
+
+      SdkComponents sdkComponents = SdkComponents.create();
+      CombinePayload combineProto = CombineTranslation.toProto(combine.get(), sdkComponents);
+      RunnerApi.Components componentsProto = sdkComponents.toComponents();
+
+      assertEquals(
+          combineFn.getAccumulatorCoder(pipeline.getCoderRegistry(), input.getCoder()),
+          CombineTranslation.getAccumulatorCoder(combineProto, componentsProto));
+      assertEquals(combineFn, CombineTranslation.getCombineFn(combineProto));
+    }
   }
 
   private static class TestCombineFn extends Combine.CombineFn<Integer, Void, Void> {
@@ -127,4 +187,43 @@ public class CombineTranslationTest {
       return TestCombineFn.class.hashCode();
     }
   }
+
+  /** Sums {@link Integer} inputs into a one-element {@code int[]}; all instances are equal. */
+  private static class TestCombineFnWithContext
+      extends CombineFnWithContext<Integer, int[], Integer> {
+    @Override
+    public int[] createAccumulator(Context c) {
+      return new int[1];
+    }
+
+    @Override
+    public int[] addInput(int[] accumulator, Integer input, Context c) {
+      accumulator[0] += input;
+      return accumulator;
+    }
+
+    @Override
+    public int[] mergeAccumulators(Iterable<int[]> accumulators, Context c) {
+      int[] res = new int[1];
+      for (int[] accum : accumulators) {
+        res[0] += accum[0];
+      }
+      return res;
+    }
+
+    @Override
+    public Integer extractOutput(int[] accumulator, Context c) {
+      return accumulator[0];
+    }
+
+    @Override
+    public boolean equals(Object other) {
+      return other instanceof TestCombineFnWithContext;  // no fields: any two instances match
+    }
+
+    @Override
+    public int hashCode() {
+      return TestCombineFnWithContext.class.hashCode();
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/c2110c97/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SdkComponentsTest.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SdkComponentsTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SdkComponentsTest.java
index 55702ea..ce6a99f 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SdkComponentsTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SdkComponentsTest.java
@@ -124,6 +124,18 @@ public class SdkComponentsTest {
                   equalTo(windowingStrategies.size()));
             } else {
               transforms.add(node);
+              if (PTransformTranslation.COMBINE_TRANSFORM_URN.equals(
+                  PTransformTranslation.urnForTransformOrNull(node.getTransform()))) {
+                // Combine translation introduces a coder that is not assigned to any PCollection
+                // in the default expansion, and must be explicitly added here.
+                try {
+                  addCoders(
+                      CombineTranslation.getAccumulatorCoder(
+                          node.toAppliedPTransform(getPipeline())));
+                } catch (IOException e) {
+                  throw new RuntimeException(e);
+                }
+              }
             }
           }
 
@@ -146,7 +158,7 @@ public class SdkComponentsTest {
           private void addCoders(Coder<?> coder) {
             coders.add(Equivalence.<Coder<?>>identity().wrap(coder));
             if (coder instanceof StructuredCoder) {
-              for (Coder<?> component : ((StructuredCoder <?>) coder).getComponents()) {
+              for (Coder<?> component : ((StructuredCoder<?>) coder).getComponents()) {
                 addCoders(component);
               }
             }