You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@beam.apache.org by ke...@apache.org on 2017/10/17 21:08:15 UTC
[07/14] beam git commit: Add custom rehydration for Combine
Add custom rehydration for Combine
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/92209c32
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/92209c32
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/92209c32
Branch: refs/heads/master
Commit: 92209c323eb54e8a57b496eb2035da44fec00714
Parents: 6abf6f5
Author: Kenneth Knowles <kl...@google.com>
Authored: Tue Oct 3 11:40:54 2017 -0700
Committer: Kenneth Knowles <ke...@apache.org>
Committed: Tue Oct 17 12:45:11 2017 -0700
----------------------------------------------------------------------
.../core/construction/CombineTranslation.java | 165 ++++++++++++++++++-
.../construction/CombineTranslationTest.java | 16 +-
2 files changed, 161 insertions(+), 20 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/beam/blob/92209c32/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java
index 69591ee..21796aa 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java
@@ -22,12 +22,15 @@ import static com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.runners.core.construction.PTransformTranslation.COMBINE_TRANSFORM_URN;
import com.google.auto.service.AutoService;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.protobuf.ByteString;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
+import javax.annotation.Nonnull;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.CombinePayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
@@ -52,12 +55,12 @@ import org.apache.beam.sdk.values.PCollection;
* RunnerApi.CombinePayload} protos.
*/
public class CombineTranslation {
+
public static final String JAVA_SERIALIZED_COMBINE_FN_URN = "urn:beam:combinefn:javasdk:v1";
/** A {@link TransformPayloadTranslator} for {@link Combine.PerKey}. */
public static class CombinePayloadTranslator
- extends PTransformTranslation.TransformPayloadTranslator.WithDefaultRehydration<
- Combine.PerKey<?, ?, ?>> {
+ implements PTransformTranslation.TransformPayloadTranslator<Combine.PerKey<?, ?, ?>> {
public static TransformPayloadTranslator create() {
return new CombinePayloadTranslator();
}
@@ -73,13 +76,25 @@ public class CombineTranslation {
public FunctionSpec translate(
AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>> transform, SdkComponents components)
throws IOException {
- CombinePayload payload = toProto(transform, components);
- return RunnerApi.FunctionSpec.newBuilder()
+ return FunctionSpec.newBuilder()
.setUrn(COMBINE_TRANSFORM_URN)
- .setPayload(payload.toByteString())
+ .setPayload(payloadForCombine((AppliedPTransform) transform, components).toByteString())
.build();
}
+ @Override
+ public PTransformTranslation.RawPTransform<?, ?> rehydrate(
+ RunnerApi.PTransform protoTransform, RehydratedComponents rehydratedComponents)
+ throws IOException {
+ checkArgument(
+ protoTransform.getSpec() != null,
+ "%s received transform with null spec",
+ getClass().getSimpleName());
+ checkArgument(protoTransform.getSpec().getUrn().equals(COMBINE_TRANSFORM_URN));
+ return new RawCombine<>(
+ CombinePayload.parseFrom(protoTransform.getSpec().getPayload()), rehydratedComponents);
+ }
+
/** Registers {@link CombinePayloadTranslator}. */
@AutoService(TransformPayloadTranslatorRegistrar.class)
public static class Registrar implements TransformPayloadTranslatorRegistrar {
@@ -90,13 +105,147 @@ public class CombineTranslation {
}
@Override
- public Map<String, TransformPayloadTranslator> getTransformRehydrators() {
- return Collections.emptyMap();
+ public Map<String, ? extends TransformPayloadTranslator> getTransformRehydrators() {
+ return Collections.singletonMap(COMBINE_TRANSFORM_URN, new CombinePayloadTranslator());
+ }
+ }
+ }
+
+ /**
+ * These methods drive to-proto translation for both Java SDK transforms and rehydrated
+ * transforms.
+ */
+ interface CombineLike {
+ RunnerApi.SdkFunctionSpec getCombineFn();
+
+ Coder<?> getAccumulatorCoder();
+
+ Map<String, RunnerApi.SideInput> getSideInputs();
+ }
+
+ /** Produces a {@link RunnerApi.CombinePayload} from a portable {@link CombineLike}. */
+ static RunnerApi.CombinePayload payloadForCombineLike(
+ CombineLike combine, SdkComponents components) throws IOException {
+ return RunnerApi.CombinePayload.newBuilder()
+ .setAccumulatorCoderId(components.registerCoder(combine.getAccumulatorCoder()))
+ .putAllSideInputs(combine.getSideInputs())
+ .setCombineFn(combine.getCombineFn())
+ .build();
+ }
+
+ static <K, InputT, OutputT> CombinePayload payloadForCombine(
+ final AppliedPTransform<
+ PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>,
+ Combine.PerKey<K, InputT, OutputT>>
+ combine,
+ SdkComponents components)
+ throws IOException {
+
+ return payloadForCombineLike(
+ new CombineLike() {
+ @Override
+ public SdkFunctionSpec getCombineFn() {
+ return SdkFunctionSpec.newBuilder()
+ // TODO: Set Java SDK Environment
+ .setSpec(
+ FunctionSpec.newBuilder()
+ .setUrn(JAVA_SERIALIZED_COMBINE_FN_URN)
+ .setPayload(
+ ByteString.copyFrom(
+ SerializableUtils.serializeToByteArray(
+ combine.getTransform().getFn())))
+ .build())
+ .build();
+ }
+
+ @Override
+ public Coder<?> getAccumulatorCoder() {
+ GlobalCombineFn<?, ?, ?> combineFn = combine.getTransform().getFn();
+ try {
+ return extractAccumulatorCoder(combineFn, (AppliedPTransform) combine);
+ } catch (CannotProvideCoderException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+ @Override
+ public Map<String, SideInput> getSideInputs() {
+ // TODO: support side inputs
+ return ImmutableMap.of();
+ }
+ },
+ components);
+ }
+
+ private static class RawCombine<K, InputT, AccumT, OutputT>
+ extends PTransformTranslation.RawPTransform<
+ PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>
+ implements CombineLike {
+
+ private final transient RehydratedComponents rehydratedComponents;
+ private final FunctionSpec spec;
+ private final CombinePayload payload;
+ private final Coder<AccumT> accumulatorCoder;
+
+ private RawCombine(CombinePayload payload, RehydratedComponents rehydratedComponents) {
+ this.rehydratedComponents = rehydratedComponents;
+ this.payload = payload;
+ this.spec =
+ FunctionSpec.newBuilder()
+ .setUrn(COMBINE_TRANSFORM_URN)
+ .setPayload(payload.toByteString())
+ .build();
+
+ // Eagerly extract the coder to throw a good exception here
+ try {
+ this.accumulatorCoder =
+ (Coder<AccumT>) rehydratedComponents.getCoder(payload.getAccumulatorCoderId());
+ } catch (IOException exc) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Failure extracting accumulator coder with id '%s' for %s",
+ payload.getAccumulatorCoderId(), Combine.class.getSimpleName()),
+ exc);
}
}
+
+ @Override
+ public String getUrn() {
+ return COMBINE_TRANSFORM_URN;
+ }
+
+ @Nonnull
+ @Override
+ public FunctionSpec getSpec() {
+ return spec;
+ }
+
+ @Override
+ public RunnerApi.FunctionSpec migrate(SdkComponents sdkComponents) throws IOException {
+ return RunnerApi.FunctionSpec.newBuilder()
+ .setUrn(COMBINE_TRANSFORM_URN)
+ .setPayload(payloadForCombineLike(this, sdkComponents).toByteString())
+ .build();
+ }
+
+ @Override
+ public SdkFunctionSpec getCombineFn() {
+ return payload.getCombineFn();
+ }
+
+ @Override
+ public Coder<?> getAccumulatorCoder() {
+ return accumulatorCoder;
+ }
+
+ @Override
+ public Map<String, SideInput> getSideInputs() {
+ return payload.getSideInputsMap();
+ }
}
- public static CombinePayload toProto(
+ @VisibleForTesting
+ static CombinePayload toProto(
AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>> combine, SdkComponents sdkComponents)
throws IOException {
GlobalCombineFn<?, ?, ?> combineFn = combine.getTransform().getFn();
http://git-wip-us.apache.org/repos/asf/beam/blob/92209c32/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java
index 8740d7f..af162d3 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java
@@ -52,15 +52,11 @@ import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;
-/**
- * Tests for {@link CombineTranslation}.
- */
+/** Tests for {@link CombineTranslation}. */
@RunWith(Enclosed.class)
public class CombineTranslationTest {
- /**
- * Tests that simple {@link CombineFn CombineFns} can be translated to and from proto.
- */
+ /** Tests that simple {@link CombineFn CombineFns} can be translated to and from proto. */
@RunWith(Parameterized.class)
public static class TranslateSimpleCombinesTest {
@Parameters(name = "{index}: {0}")
@@ -111,14 +107,10 @@ public class CombineTranslationTest {
}
}
-
- /**
- * Tests that a {@link CombineFnWithContext} can be translated.
- */
+ /** Tests that a {@link CombineFnWithContext} can be translated. */
@RunWith(JUnit4.class)
public static class ValidateCombineWithContextTest {
- @Rule
- public TestPipeline pipeline = TestPipeline.create();
+ @Rule public TestPipeline pipeline = TestPipeline.create();
@Test
public void testToFromProtoWithSideInputs() throws Exception {