You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@beam.apache.org by ke...@apache.org on 2017/10/17 21:08:15 UTC

[07/14] beam git commit: Add custom rehydration for Combine

Add custom rehydration for Combine


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/92209c32
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/92209c32
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/92209c32

Branch: refs/heads/master
Commit: 92209c323eb54e8a57b496eb2035da44fec00714
Parents: 6abf6f5
Author: Kenneth Knowles <kl...@google.com>
Authored: Tue Oct 3 11:40:54 2017 -0700
Committer: Kenneth Knowles <ke...@apache.org>
Committed: Tue Oct 17 12:45:11 2017 -0700

----------------------------------------------------------------------
 .../core/construction/CombineTranslation.java   | 165 ++++++++++++++++++-
 .../construction/CombineTranslationTest.java    |  16 +-
 2 files changed, 161 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/92209c32/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java
index 69591ee..21796aa 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java
@@ -22,12 +22,15 @@ import static com.google.common.base.Preconditions.checkArgument;
 import static org.apache.beam.runners.core.construction.PTransformTranslation.COMBINE_TRANSFORM_URN;
 
 import com.google.auto.service.AutoService;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Iterables;
 import com.google.protobuf.ByteString;
 import java.io.IOException;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
+import javax.annotation.Nonnull;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.model.pipeline.v1.RunnerApi.CombinePayload;
 import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
@@ -52,12 +55,12 @@ import org.apache.beam.sdk.values.PCollection;
  * RunnerApi.CombinePayload} protos.
  */
 public class CombineTranslation {
+
   public static final String JAVA_SERIALIZED_COMBINE_FN_URN = "urn:beam:combinefn:javasdk:v1";
 
   /** A {@link TransformPayloadTranslator} for {@link Combine.PerKey}. */
   public static class CombinePayloadTranslator
-      extends PTransformTranslation.TransformPayloadTranslator.WithDefaultRehydration<
-          Combine.PerKey<?, ?, ?>> {
+      implements PTransformTranslation.TransformPayloadTranslator<Combine.PerKey<?, ?, ?>> {
     public static TransformPayloadTranslator create() {
       return new CombinePayloadTranslator();
     }
@@ -73,13 +76,25 @@ public class CombineTranslation {
     public FunctionSpec translate(
         AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>> transform, SdkComponents components)
         throws IOException {
-      CombinePayload payload = toProto(transform, components);
-      return RunnerApi.FunctionSpec.newBuilder()
+      return FunctionSpec.newBuilder()
           .setUrn(COMBINE_TRANSFORM_URN)
-          .setPayload(payload.toByteString())
+          .setPayload(payloadForCombine((AppliedPTransform) transform, components).toByteString())
           .build();
     }
 
+    @Override
+    public PTransformTranslation.RawPTransform<?, ?> rehydrate(
+        RunnerApi.PTransform protoTransform, RehydratedComponents rehydratedComponents)
+        throws IOException {
+      checkArgument(
+          protoTransform.getSpec() != null,
+          "%s received transform with null spec",
+          getClass().getSimpleName());
+      checkArgument(protoTransform.getSpec().getUrn().equals(COMBINE_TRANSFORM_URN));
+      return new RawCombine<>(
+          CombinePayload.parseFrom(protoTransform.getSpec().getPayload()), rehydratedComponents);
+    }
+
     /** Registers {@link CombinePayloadTranslator}. */
     @AutoService(TransformPayloadTranslatorRegistrar.class)
     public static class Registrar implements TransformPayloadTranslatorRegistrar {
@@ -90,13 +105,147 @@ public class CombineTranslation {
       }
 
       @Override
-      public Map<String, TransformPayloadTranslator> getTransformRehydrators() {
-        return Collections.emptyMap();
+      public Map<String, ? extends TransformPayloadTranslator> getTransformRehydrators() {
+        return Collections.singletonMap(COMBINE_TRANSFORM_URN, new CombinePayloadTranslator());
+      }
+    }
+  }
+
+  /**
+   * These methods drive to-proto translation for both Java SDK transforms and rehydrated
+   * transforms.
+   */
+  interface CombineLike {
+    RunnerApi.SdkFunctionSpec getCombineFn();
+
+    Coder<?> getAccumulatorCoder();
+
+    Map<String, RunnerApi.SideInput> getSideInputs();
+  }
+
+  /** Produces a {@link RunnerApi.CombinePayload} from a portable {@link CombineLike}. */
+  static RunnerApi.CombinePayload payloadForCombineLike(
+      CombineLike combine, SdkComponents components) throws IOException {
+    return RunnerApi.CombinePayload.newBuilder()
+        .setAccumulatorCoderId(components.registerCoder(combine.getAccumulatorCoder()))
+        .putAllSideInputs(combine.getSideInputs())
+        .setCombineFn(combine.getCombineFn())
+        .build();
+  }
+
+  static <K, InputT, OutputT> CombinePayload payloadForCombine(
+      final AppliedPTransform<
+              PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>,
+              Combine.PerKey<K, InputT, OutputT>>
+          combine,
+      SdkComponents components)
+      throws IOException {
+
+    return payloadForCombineLike(
+        new CombineLike() {
+          @Override
+          public SdkFunctionSpec getCombineFn() {
+            return SdkFunctionSpec.newBuilder()
+                // TODO: Set Java SDK Environment
+                .setSpec(
+                    FunctionSpec.newBuilder()
+                        .setUrn(JAVA_SERIALIZED_COMBINE_FN_URN)
+                        .setPayload(
+                            ByteString.copyFrom(
+                                SerializableUtils.serializeToByteArray(
+                                    combine.getTransform().getFn())))
+                        .build())
+                .build();
+          }
+
+          @Override
+          public Coder<?> getAccumulatorCoder() {
+            GlobalCombineFn<?, ?, ?> combineFn = combine.getTransform().getFn();
+            try {
+              return extractAccumulatorCoder(combineFn, (AppliedPTransform) combine);
+            } catch (CannotProvideCoderException e) {
+              throw new IllegalStateException(e);
+            }
+          }
+
+          @Override
+          public Map<String, SideInput> getSideInputs() {
+            // TODO: support side inputs
+            return ImmutableMap.of();
+          }
+        },
+        components);
+  }
+
+  private static class RawCombine<K, InputT, AccumT, OutputT>
+      extends PTransformTranslation.RawPTransform<
+          PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>
+      implements CombineLike {
+
+    private final transient RehydratedComponents rehydratedComponents;
+    private final FunctionSpec spec;
+    private final CombinePayload payload;
+    private final Coder<AccumT> accumulatorCoder;
+
+    private RawCombine(CombinePayload payload, RehydratedComponents rehydratedComponents) {
+      this.rehydratedComponents = rehydratedComponents;
+      this.payload = payload;
+      this.spec =
+          FunctionSpec.newBuilder()
+              .setUrn(COMBINE_TRANSFORM_URN)
+              .setPayload(payload.toByteString())
+              .build();
+
+      // Eagerly extract the coder to throw a good exception here
+      try {
+        this.accumulatorCoder =
+            (Coder<AccumT>) rehydratedComponents.getCoder(payload.getAccumulatorCoderId());
+      } catch (IOException exc) {
+        throw new IllegalArgumentException(
+            String.format(
+                "Failure extracting accumulator coder with id '%s' for %s",
+                payload.getAccumulatorCoderId(), Combine.class.getSimpleName()),
+            exc);
       }
     }
+
+    @Override
+    public String getUrn() {
+      return COMBINE_TRANSFORM_URN;
+    }
+
+    @Nonnull
+    @Override
+    public FunctionSpec getSpec() {
+      return spec;
+    }
+
+    @Override
+    public RunnerApi.FunctionSpec migrate(SdkComponents sdkComponents) throws IOException {
+      return RunnerApi.FunctionSpec.newBuilder()
+          .setUrn(COMBINE_TRANSFORM_URN)
+          .setPayload(payloadForCombineLike(this, sdkComponents).toByteString())
+          .build();
+    }
+
+    @Override
+    public SdkFunctionSpec getCombineFn() {
+      return payload.getCombineFn();
+    }
+
+    @Override
+    public Coder<?> getAccumulatorCoder() {
+      return accumulatorCoder;
+    }
+
+    @Override
+    public Map<String, SideInput> getSideInputs() {
+      return payload.getSideInputsMap();
+    }
   }
 
-  public static CombinePayload toProto(
+  @VisibleForTesting
+  static CombinePayload toProto(
       AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>> combine, SdkComponents sdkComponents)
       throws IOException {
     GlobalCombineFn<?, ?, ?> combineFn = combine.getTransform().getFn();

http://git-wip-us.apache.org/repos/asf/beam/blob/92209c32/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java
index 8740d7f..af162d3 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java
@@ -52,15 +52,11 @@ import org.junit.runners.Parameterized;
 import org.junit.runners.Parameterized.Parameter;
 import org.junit.runners.Parameterized.Parameters;
 
-/**
- * Tests for {@link CombineTranslation}.
- */
+/** Tests for {@link CombineTranslation}. */
 @RunWith(Enclosed.class)
 public class CombineTranslationTest {
 
-  /**
-   * Tests that simple {@link CombineFn CombineFns} can be translated to and from proto.
-   */
+  /** Tests that simple {@link CombineFn CombineFns} can be translated to and from proto. */
   @RunWith(Parameterized.class)
   public static class TranslateSimpleCombinesTest {
     @Parameters(name = "{index}: {0}")
@@ -111,14 +107,10 @@ public class CombineTranslationTest {
     }
   }
 
-
-  /**
-   * Tests that a {@link CombineFnWithContext} can be translated.
-   */
+  /** Tests that a {@link CombineFnWithContext} can be translated. */
   @RunWith(JUnit4.class)
   public static class ValidateCombineWithContextTest {
-    @Rule
-    public TestPipeline pipeline = TestPipeline.create();
+    @Rule public TestPipeline pipeline = TestPipeline.create();
 
     @Test
     public void testToFromProtoWithSideInputs() throws Exception {