You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2017/06/16 19:03:48 UTC
[2/3] beam git commit: [BEAM-1347] Break apart ProcessBundleHandler to use service locator pattern based upon URNs.
[BEAM-1347] Break apart ProcessBundleHandler to use service locator pattern based upon URNs.
This cleans up ProcessBundleHandler and allows for separate improvements of the various PTransform handler factories.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/c9c1a05d
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/c9c1a05d
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/c9c1a05d
Branch: refs/heads/master
Commit: c9c1a05dc07a9a7e57fefbe6e43f723b330499d5
Parents: 54f3078
Author: Luke Cwik <lc...@google.com>
Authored: Thu Jun 15 16:36:22 2017 -0700
Committer: Luke Cwik <lc...@google.com>
Committed: Fri Jun 16 12:03:06 2017 -0700
----------------------------------------------------------------------
sdks/java/harness/pom.xml | 6 +
.../harness/control/ProcessBundleHandler.java | 293 +++--------
.../beam/runners/core/BeamFnDataReadRunner.java | 70 ++-
.../runners/core/BeamFnDataWriteRunner.java | 67 ++-
.../beam/runners/core/BoundedSourceRunner.java | 74 ++-
.../beam/runners/core/DoFnRunnerFactory.java | 182 +++++++
.../runners/core/PTransformRunnerFactory.java | 81 +++
.../control/ProcessBundleHandlerTest.java | 521 +++----------------
.../runners/core/BeamFnDataReadRunnerTest.java | 112 +++-
.../runners/core/BeamFnDataWriteRunnerTest.java | 120 ++++-
.../runners/core/BoundedSourceRunnerTest.java | 124 ++++-
.../runners/core/DoFnRunnerFactoryTest.java | 209 ++++++++
12 files changed, 1134 insertions(+), 725 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/beam/blob/c9c1a05d/sdks/java/harness/pom.xml
----------------------------------------------------------------------
diff --git a/sdks/java/harness/pom.xml b/sdks/java/harness/pom.xml
index 61a170a..a35481d 100644
--- a/sdks/java/harness/pom.xml
+++ b/sdks/java/harness/pom.xml
@@ -154,6 +154,12 @@
<artifactId>slf4j-api</artifactId>
</dependency>
+ <dependency>
+ <groupId>com.google.auto.service</groupId>
+ <artifactId>auto-service</artifactId>
+ <optional>true</optional>
+ </dependency>
+
<!-- test dependencies -->
<dependency>
<groupId>org.hamcrest</groupId>
http://git-wip-us.apache.org/repos/asf/beam/blob/c9c1a05d/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index e33277a..4c4f73d 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -18,51 +18,32 @@
package org.apache.beam.fn.harness.control;
-import static com.google.common.base.Preconditions.checkArgument;
-import static com.google.common.collect.Iterables.getOnlyElement;
-
-import com.google.common.collect.Collections2;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
-import com.google.protobuf.ByteString;
-import com.google.protobuf.BytesValue;
-import com.google.protobuf.InvalidProtocolBufferException;
+import com.google.common.collect.Sets;
import com.google.protobuf.Message;
import java.io.IOException;
import java.util.ArrayList;
-import java.util.Collection;
-import java.util.HashSet;
import java.util.List;
import java.util.Map;
-import java.util.Objects;
+import java.util.ServiceLoader;
+import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
-import org.apache.beam.fn.harness.fake.FakeStepContext;
import org.apache.beam.fn.harness.fn.ThrowingConsumer;
import org.apache.beam.fn.harness.fn.ThrowingRunnable;
import org.apache.beam.fn.v1.BeamFnApi;
-import org.apache.beam.runners.core.BeamFnDataReadRunner;
-import org.apache.beam.runners.core.BeamFnDataWriteRunner;
-import org.apache.beam.runners.core.BoundedSourceRunner;
-import org.apache.beam.runners.core.DoFnRunner;
-import org.apache.beam.runners.core.DoFnRunners;
-import org.apache.beam.runners.core.DoFnRunners.OutputManager;
-import org.apache.beam.runners.core.NullSideInputReader;
-import org.apache.beam.runners.dataflow.util.DoFnInfo;
+import org.apache.beam.runners.core.PTransformRunnerFactory;
+import org.apache.beam.runners.core.PTransformRunnerFactory.Registrar;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
-import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
-import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.beam.sdk.util.common.ReflectHelpers;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -75,25 +56,73 @@ import org.slf4j.LoggerFactory;
* and finishing all runners in forward topological order.
*/
public class ProcessBundleHandler {
+
// TODO: What should the initial set of URNs be?
private static final String DATA_INPUT_URN = "urn:org.apache.beam:source:runner:0.1";
- private static final String DATA_OUTPUT_URN = "urn:org.apache.beam:sink:runner:0.1";
- private static final String JAVA_DO_FN_URN = "urn:org.apache.beam:dofn:java:0.1";
- private static final String JAVA_SOURCE_URN = "urn:org.apache.beam:source:java:0.1";
+ public static final String JAVA_SOURCE_URN = "urn:org.apache.beam:source:java:0.1";
private static final Logger LOG = LoggerFactory.getLogger(ProcessBundleHandler.class);
+ private static final Map<String, PTransformRunnerFactory> REGISTERED_RUNNER_FACTORIES;
+
+ static {
+ Set<Registrar> pipelineRunnerRegistrars =
+ Sets.newTreeSet(ReflectHelpers.ObjectsClassComparator.INSTANCE);
+ pipelineRunnerRegistrars.addAll(
+ Lists.newArrayList(ServiceLoader.load(Registrar.class,
+ ReflectHelpers.findClassLoader())));
+
+ // Load all registered PTransform runner factories.
+ ImmutableMap.Builder<String, PTransformRunnerFactory> builder =
+ ImmutableMap.builder();
+ for (Registrar registrar : pipelineRunnerRegistrars) {
+ builder.putAll(registrar.getPTransformRunnerFactories());
+ }
+ REGISTERED_RUNNER_FACTORIES = builder.build();
+ }
private final PipelineOptions options;
private final Function<String, Message> fnApiRegistry;
private final BeamFnDataClient beamFnDataClient;
+ private final Map<String, PTransformRunnerFactory> urnToPTransformRunnerFactoryMap;
+ private final PTransformRunnerFactory defaultPTransformRunnerFactory;
+
public ProcessBundleHandler(
PipelineOptions options,
Function<String, Message> fnApiRegistry,
BeamFnDataClient beamFnDataClient) {
+ this(options, fnApiRegistry, beamFnDataClient, REGISTERED_RUNNER_FACTORIES);
+ }
+
+ @VisibleForTesting
+ ProcessBundleHandler(
+ PipelineOptions options,
+ Function<String, Message> fnApiRegistry,
+ BeamFnDataClient beamFnDataClient,
+ Map<String, PTransformRunnerFactory> urnToPTransformRunnerFactoryMap) {
this.options = options;
this.fnApiRegistry = fnApiRegistry;
this.beamFnDataClient = beamFnDataClient;
+ this.urnToPTransformRunnerFactoryMap = urnToPTransformRunnerFactoryMap;
+ this.defaultPTransformRunnerFactory = new PTransformRunnerFactory<Object>() {
+ @Override
+ public Object createRunnerForPTransform(
+ PipelineOptions pipelineOptions,
+ BeamFnDataClient beamFnDataClient,
+ String pTransformId,
+ RunnerApi.PTransform pTransform,
+ Supplier<String> processBundleInstructionId,
+ Map<String, RunnerApi.PCollection> pCollections,
+ Map<String, RunnerApi.Coder> coders,
+ Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
+ Consumer<ThrowingRunnable> addStartFunction,
+ Consumer<ThrowingRunnable> addFinishFunction) {
+ throw new IllegalStateException(String.format(
+ "No factory registered for %s, known factories %s",
+ pTransform.getSpec().getUrn(),
+ urnToPTransformRunnerFactoryMap.keySet()));
+ }
+ };
}
private void createRunnerAndConsumersForPTransformRecursively(
@@ -128,115 +157,19 @@ public class ProcessBundleHandler {
}
}
- createRunnerForPTransform(
- pTransformId,
- pTransform,
- processBundleInstructionId,
- processBundleDescriptor.getPcollectionsMap(),
- pCollectionIdsToConsumers,
- addStartFunction,
- addFinishFunction);
- }
-
- protected void createRunnerForPTransform(
- String pTransformId,
- RunnerApi.PTransform pTransform,
- Supplier<String> processBundleInstructionId,
- Map<String, RunnerApi.PCollection> pCollections,
- Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
- Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
-
-
- // For every output PCollection, create a map from output name to Consumer
- ImmutableMap.Builder<String, Collection<ThrowingConsumer<WindowedValue<?>>>>
- outputMapBuilder = ImmutableMap.builder();
- for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
- outputMapBuilder.put(
- entry.getKey(),
- pCollectionIdsToConsumers.get(entry.getValue()));
- }
- ImmutableMap<String, Collection<ThrowingConsumer<WindowedValue<?>>>> outputMap =
- outputMapBuilder.build();
-
-
- // Based upon the function spec, populate the start/finish/consumer information.
- RunnerApi.FunctionSpec functionSpec = pTransform.getSpec();
- ThrowingConsumer<WindowedValue<?>> consumer;
- switch (functionSpec.getUrn()) {
- default:
- BeamFnApi.Target target;
- RunnerApi.Coder coderSpec;
- throw new IllegalArgumentException(
- String.format("Unknown FunctionSpec %s", functionSpec));
-
- case DATA_OUTPUT_URN:
- target = BeamFnApi.Target.newBuilder()
- .setPrimitiveTransformReference(pTransformId)
- .setName(getOnlyElement(pTransform.getInputsMap().keySet()))
- .build();
- coderSpec = (RunnerApi.Coder) fnApiRegistry.apply(
- pCollections.get(getOnlyElement(pTransform.getInputsMap().values())).getCoderId());
- BeamFnDataWriteRunner<Object> remoteGrpcWriteRunner =
- new BeamFnDataWriteRunner<Object>(
- functionSpec,
- processBundleInstructionId,
- target,
- coderSpec,
- beamFnDataClient);
- addStartFunction.accept(remoteGrpcWriteRunner::registerForOutput);
- consumer = (ThrowingConsumer)
- (ThrowingConsumer<WindowedValue<Object>>) remoteGrpcWriteRunner::consume;
- addFinishFunction.accept(remoteGrpcWriteRunner::close);
- break;
-
- case DATA_INPUT_URN:
- target = BeamFnApi.Target.newBuilder()
- .setPrimitiveTransformReference(pTransformId)
- .setName(getOnlyElement(pTransform.getOutputsMap().keySet()))
- .build();
- coderSpec = (RunnerApi.Coder) fnApiRegistry.apply(
- pCollections.get(getOnlyElement(pTransform.getOutputsMap().values())).getCoderId());
- BeamFnDataReadRunner<?> remoteGrpcReadRunner =
- new BeamFnDataReadRunner<Object>(
- functionSpec,
- processBundleInstructionId,
- target,
- coderSpec,
- beamFnDataClient,
- (Map) outputMap);
- addStartFunction.accept(remoteGrpcReadRunner::registerInputLocation);
- consumer = null;
- addFinishFunction.accept(remoteGrpcReadRunner::blockTillReadFinishes);
- break;
-
- case JAVA_DO_FN_URN:
- DoFnRunner<Object, Object> doFnRunner = createDoFnRunner(functionSpec, (Map) outputMap);
- addStartFunction.accept(doFnRunner::startBundle);
- consumer = (ThrowingConsumer)
- (ThrowingConsumer<WindowedValue<Object>>) doFnRunner::processElement;
- addFinishFunction.accept(doFnRunner::finishBundle);
- break;
-
- case JAVA_SOURCE_URN:
- @SuppressWarnings({"unchecked", "rawtypes"})
- BoundedSourceRunner<BoundedSource<Object>, Object> sourceRunner =
- createBoundedSourceRunner(functionSpec, (Map) outputMap);
- // TODO: Remove and replace with source being sent across gRPC port
- addStartFunction.accept(sourceRunner::start);
- consumer = (ThrowingConsumer)
- (ThrowingConsumer<WindowedValue<BoundedSource<Object>>>)
- sourceRunner::runReadLoop;
- break;
- }
-
- // If we created a consumer, add it to the map containing PCollection ids to consumers
- if (consumer != null) {
- for (String inputPCollectionId :
- pTransform.getInputsMap().values()) {
- pCollectionIdsToConsumers.put(inputPCollectionId, consumer);
- }
- }
+ urnToPTransformRunnerFactoryMap.getOrDefault(
+ pTransform.getSpec().getUrn(), defaultPTransformRunnerFactory)
+ .createRunnerForPTransform(
+ options,
+ beamFnDataClient,
+ pTransformId,
+ pTransform,
+ processBundleInstructionId,
+ processBundleDescriptor.getPcollectionsMap(),
+ processBundleDescriptor.getCodersyyyMap(),
+ pCollectionIdsToConsumers,
+ addStartFunction,
+ addFinishFunction);
}
public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.InstructionRequest request)
@@ -299,88 +232,4 @@ public class ProcessBundleHandler {
return response;
}
-
- /**
- * Converts a {@link org.apache.beam.fn.v1.BeamFnApi.FunctionSpec} into a {@link DoFnRunner}.
- */
- private <InputT, OutputT> DoFnRunner<InputT, OutputT> createDoFnRunner(
- RunnerApi.FunctionSpec functionSpec,
- Map<String, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> outputMap) {
- ByteString serializedFn;
- try {
- serializedFn = functionSpec.getParameter().unpack(BytesValue.class).getValue();
- } catch (InvalidProtocolBufferException e) {
- throw new IllegalArgumentException(
- String.format("Unable to unwrap DoFn %s", functionSpec), e);
- }
- DoFnInfo<?, ?> doFnInfo =
- (DoFnInfo<?, ?>)
- SerializableUtils.deserializeFromByteArray(serializedFn.toByteArray(), "DoFnInfo");
-
- checkArgument(
- Objects.equals(
- new HashSet<>(Collections2.transform(outputMap.keySet(), Long::parseLong)),
- doFnInfo.getOutputMap().keySet()),
- "Unexpected mismatch between transform output map %s and DoFnInfo output map %s.",
- outputMap.keySet(),
- doFnInfo.getOutputMap());
-
- ImmutableMultimap.Builder<TupleTag<?>,
- ThrowingConsumer<WindowedValue<OutputT>>> tagToOutput =
- ImmutableMultimap.builder();
- for (Map.Entry<Long, TupleTag<?>> entry : doFnInfo.getOutputMap().entrySet()) {
- tagToOutput.putAll(entry.getValue(), outputMap.get(Long.toString(entry.getKey())));
- }
- @SuppressWarnings({"unchecked", "rawtypes"})
- final Map<TupleTag<?>, Collection<ThrowingConsumer<WindowedValue<?>>>> tagBasedOutputMap =
- (Map) tagToOutput.build().asMap();
-
- OutputManager outputManager =
- new OutputManager() {
- Map<TupleTag<?>, Collection<ThrowingConsumer<WindowedValue<?>>>> tupleTagToOutput =
- tagBasedOutputMap;
-
- @Override
- public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
- try {
- Collection<ThrowingConsumer<WindowedValue<?>>> consumers =
- tupleTagToOutput.get(tag);
- if (consumers == null) {
- /* This is a normal case, e.g., if a DoFn has output but that output is not
- * consumed. Drop the output. */
- return;
- }
- for (ThrowingConsumer<WindowedValue<?>> consumer : consumers) {
- consumer.accept(output);
- }
- } catch (Throwable t) {
- throw new RuntimeException(t);
- }
- }
- };
-
- @SuppressWarnings({"unchecked", "rawtypes", "deprecation"})
- DoFnRunner<InputT, OutputT> runner =
- DoFnRunners.simpleRunner(
- PipelineOptionsFactory.create(), /* TODO */
- (DoFn) doFnInfo.getDoFn(),
- NullSideInputReader.empty(), /* TODO */
- outputManager,
- (TupleTag) doFnInfo.getOutputMap().get(doFnInfo.getMainOutput()),
- new ArrayList<>(doFnInfo.getOutputMap().values()),
- new FakeStepContext(),
- (WindowingStrategy) doFnInfo.getWindowingStrategy());
- return runner;
- }
-
- private <InputT extends BoundedSource<OutputT>, OutputT>
- BoundedSourceRunner<InputT, OutputT> createBoundedSourceRunner(
- RunnerApi.FunctionSpec functionSpec,
- Map<String, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> outputMap) {
-
- @SuppressWarnings({"rawtypes", "unchecked"})
- BoundedSourceRunner<InputT, OutputT> runner =
- new BoundedSourceRunner(options, functionSpec, outputMap);
- return runner;
- }
}
http://git-wip-us.apache.org/repos/asf/beam/blob/c9c1a05d/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java
index f0fe274..9339347 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java
@@ -18,22 +18,28 @@
package org.apache.beam.runners.core;
+import static com.google.common.collect.Iterables.getOnlyElement;
+
import com.fasterxml.jackson.databind.ObjectMapper;
-import com.google.common.collect.FluentIterable;
-import com.google.common.collect.ImmutableList;
+import com.google.auto.service.AutoService;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Multimap;
import com.google.protobuf.BytesValue;
import java.io.IOException;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
+import java.util.function.Consumer;
import java.util.function.Supplier;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.fn.ThrowingConsumer;
+import org.apache.beam.fn.harness.fn.ThrowingRunnable;
import org.apache.beam.fn.v1.BeamFnApi;
import org.apache.beam.runners.dataflow.util.CloudObject;
import org.apache.beam.runners.dataflow.util.CloudObjects;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.slf4j.Logger;
@@ -48,9 +54,61 @@ import org.slf4j.LoggerFactory;
* {@link #blockTillReadFinishes()} to finish.
*/
public class BeamFnDataReadRunner<OutputT> {
- private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataReadRunner.class);
+ private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataReadRunner.class);
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+ private static final String URN = "urn:org.apache.beam:source:runner:0.1";
+
+ /** A registrar which provides a factory to handle reading from the Fn Api Data Plane. */
+ @AutoService(PTransformRunnerFactory.Registrar.class)
+ public static class Registrar implements
+ PTransformRunnerFactory.Registrar {
+
+ @Override
+ public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() {
+ return ImmutableMap.of(URN, new Factory());
+ }
+ }
+
+ /** A factory for {@link BeamFnDataReadRunner}s. */
+ static class Factory<OutputT>
+ implements PTransformRunnerFactory<BeamFnDataReadRunner<OutputT>> {
+
+ @Override
+ public BeamFnDataReadRunner<OutputT> createRunnerForPTransform(
+ PipelineOptions pipelineOptions,
+ BeamFnDataClient beamFnDataClient,
+ String pTransformId,
+ RunnerApi.PTransform pTransform,
+ Supplier<String> processBundleInstructionId,
+ Map<String, RunnerApi.PCollection> pCollections,
+ Map<String, RunnerApi.Coder> coders,
+ Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
+ Consumer<ThrowingRunnable> addStartFunction,
+ Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+
+ BeamFnApi.Target target = BeamFnApi.Target.newBuilder()
+ .setPrimitiveTransformReference(pTransformId)
+ .setName(getOnlyElement(pTransform.getOutputsMap().keySet()))
+ .build();
+ RunnerApi.Coder coderSpec = coders.get(pCollections.get(
+ getOnlyElement(pTransform.getOutputsMap().values())).getCoderId());
+ Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers =
+ (Collection) pCollectionIdsToConsumers.get(
+ getOnlyElement(pTransform.getOutputsMap().values()));
+
+ BeamFnDataReadRunner<OutputT> runner = new BeamFnDataReadRunner<>(
+ pTransform.getSpec(),
+ processBundleInstructionId,
+ target,
+ coderSpec,
+ beamFnDataClient,
+ consumers);
+ addStartFunction.accept(runner::registerInputLocation);
+ addFinishFunction.accept(runner::blockTillReadFinishes);
+ return runner;
+ }
+ }
private final BeamFnApi.ApiServiceDescriptor apiServiceDescriptor;
private final Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers;
@@ -61,20 +119,20 @@ public class BeamFnDataReadRunner<OutputT> {
private CompletableFuture<Void> readFuture;
- public BeamFnDataReadRunner(
+ BeamFnDataReadRunner(
RunnerApi.FunctionSpec functionSpec,
Supplier<String> processBundleInstructionIdSupplier,
BeamFnApi.Target inputTarget,
RunnerApi.Coder coderSpec,
BeamFnDataClient beamFnDataClientFactory,
- Map<String, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> outputMap)
+ Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers)
throws IOException {
this.apiServiceDescriptor = functionSpec.getParameter().unpack(BeamFnApi.RemoteGrpcPort.class)
.getApiServiceDescriptor();
this.inputTarget = inputTarget;
this.processBundleInstructionIdSupplier = processBundleInstructionIdSupplier;
this.beamFnDataClientFactory = beamFnDataClientFactory;
- this.consumers = ImmutableList.copyOf(FluentIterable.concat(outputMap.values()));
+ this.consumers = consumers;
@SuppressWarnings("unchecked")
Coder<WindowedValue<OutputT>> coder =
http://git-wip-us.apache.org/repos/asf/beam/blob/c9c1a05d/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java
index a48df12..c2a996b 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java
@@ -18,30 +18,91 @@
package org.apache.beam.runners.core;
+import static com.google.common.collect.Iterables.getOnlyElement;
+
import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.auto.service.AutoService;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Multimap;
import com.google.protobuf.BytesValue;
import java.io.IOException;
import java.util.Map;
+import java.util.function.Consumer;
import java.util.function.Supplier;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer;
+import org.apache.beam.fn.harness.fn.ThrowingConsumer;
+import org.apache.beam.fn.harness.fn.ThrowingRunnable;
import org.apache.beam.fn.v1.BeamFnApi;
import org.apache.beam.runners.dataflow.util.CloudObject;
import org.apache.beam.runners.dataflow.util.CloudObjects;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
/**
- * Registers as a consumer with the Beam Fn Data API. Propagates and elements consumed to
- * the the registered consumer.
+ * Registers as a consumer with the Beam Fn Data Api. Consumes elements and encodes them for
+ * transmission.
*
* <p>Can be re-used serially across {@link org.apache.beam.fn.v1.BeamFnApi.ProcessBundleRequest}s.
* For each request, call {@link #registerForOutput()} to start and call {@link #close()} to finish.
*/
public class BeamFnDataWriteRunner<InputT> {
+
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+ private static final String URN = "urn:org.apache.beam:sink:runner:0.1";
+
+ /** A registrar which provides a factory to handle writing to the Fn Api Data Plane. */
+ @AutoService(PTransformRunnerFactory.Registrar.class)
+ public static class Registrar implements
+ PTransformRunnerFactory.Registrar {
+
+ @Override
+ public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() {
+ return ImmutableMap.of(URN, new Factory());
+ }
+ }
+
+ /** A factory for {@link BeamFnDataWriteRunner}s. */
+ static class Factory<InputT>
+ implements PTransformRunnerFactory<BeamFnDataWriteRunner<InputT>> {
+
+ @Override
+ public BeamFnDataWriteRunner<InputT> createRunnerForPTransform(
+ PipelineOptions pipelineOptions,
+ BeamFnDataClient beamFnDataClient,
+ String pTransformId,
+ RunnerApi.PTransform pTransform,
+ Supplier<String> processBundleInstructionId,
+ Map<String, RunnerApi.PCollection> pCollections,
+ Map<String, RunnerApi.Coder> coders,
+ Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
+ Consumer<ThrowingRunnable> addStartFunction,
+ Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+ BeamFnApi.Target target = BeamFnApi.Target.newBuilder()
+ .setPrimitiveTransformReference(pTransformId)
+ .setName(getOnlyElement(pTransform.getInputsMap().keySet()))
+ .build();
+ RunnerApi.Coder coderSpec = coders.get(
+ pCollections.get(getOnlyElement(pTransform.getInputsMap().values())).getCoderId());
+ BeamFnDataWriteRunner<InputT> runner =
+ new BeamFnDataWriteRunner<>(
+ pTransform.getSpec(),
+ processBundleInstructionId,
+ target,
+ coderSpec,
+ beamFnDataClient);
+ addStartFunction.accept(runner::registerForOutput);
+ pCollectionIdsToConsumers.put(
+ getOnlyElement(pTransform.getInputsMap().values()),
+ (ThrowingConsumer)
+ (ThrowingConsumer<WindowedValue<InputT>>) runner::consume);
+ addFinishFunction.accept(runner::close);
+ return runner;
+ }
+ }
private final BeamFnApi.ApiServiceDescriptor apiServiceDescriptor;
private final BeamFnApi.Target outputTarget;
@@ -51,7 +112,7 @@ public class BeamFnDataWriteRunner<InputT> {
private CloseableThrowingConsumer<WindowedValue<InputT>> consumer;
- public BeamFnDataWriteRunner(
+ BeamFnDataWriteRunner(
RunnerApi.FunctionSpec functionSpec,
Supplier<String> processBundleInstructionIdSupplier,
BeamFnApi.Target outputTarget,
http://git-wip-us.apache.org/repos/asf/beam/blob/c9c1a05d/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java
index 4d530b8..3338c3a 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java
@@ -18,14 +18,20 @@
package org.apache.beam.runners.core;
-import com.google.common.collect.FluentIterable;
+import com.google.auto.service.AutoService;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Multimap;
import com.google.protobuf.BytesValue;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.IOException;
import java.util.Collection;
import java.util.Map;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.fn.ThrowingConsumer;
+import org.apache.beam.fn.harness.fn.ThrowingRunnable;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.io.Source.Reader;
@@ -34,21 +40,77 @@ import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.WindowedValue;
/**
- * A runner which creates {@link Reader}s for each {@link BoundedSource} and executes
- * the {@link Reader}s read loop.
+ * A runner which creates {@link Reader}s for each {@link BoundedSource} sent as an input and
+ * executes the {@link Reader}s read loop.
*/
public class BoundedSourceRunner<InputT extends BoundedSource<OutputT>, OutputT> {
+
+ private static final String URN = "urn:org.apache.beam:source:java:0.1";
+
+ /** A registrar which provides a factory to handle Java {@link BoundedSource}s. */
+ @AutoService(PTransformRunnerFactory.Registrar.class)
+ public static class Registrar implements
+ PTransformRunnerFactory.Registrar {
+
+ @Override
+ public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() {
+ return ImmutableMap.of(URN, new Factory());
+ }
+ }
+
+ /** A factory for {@link BoundedSourceRunner}. */
+ static class Factory<InputT extends BoundedSource<OutputT>, OutputT>
+ implements PTransformRunnerFactory<BoundedSourceRunner<InputT, OutputT>> {
+ @Override
+ public BoundedSourceRunner<InputT, OutputT> createRunnerForPTransform(
+ PipelineOptions pipelineOptions,
+ BeamFnDataClient beamFnDataClient,
+ String pTransformId,
+ RunnerApi.PTransform pTransform,
+ Supplier<String> processBundleInstructionId,
+ Map<String, RunnerApi.PCollection> pCollections,
+ Map<String, RunnerApi.Coder> coders,
+ Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
+ Consumer<ThrowingRunnable> addStartFunction,
+ Consumer<ThrowingRunnable> addFinishFunction) {
+
+ ImmutableList.Builder<ThrowingConsumer<WindowedValue<?>>> consumers = ImmutableList.builder();
+ for (String pCollectionId : pTransform.getOutputsMap().values()) {
+ consumers.addAll(pCollectionIdsToConsumers.get(pCollectionId));
+ }
+
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ BoundedSourceRunner<InputT, OutputT> runner = new BoundedSourceRunner(
+ pipelineOptions,
+ pTransform.getSpec(),
+ consumers.build());
+
+ // TODO: Remove and replace with source being sent across gRPC port
+ addStartFunction.accept(runner::start);
+
+ ThrowingConsumer runReadLoop =
+ (ThrowingConsumer<WindowedValue<InputT>>) runner::runReadLoop;
+ for (String pCollectionId : pTransform.getInputsMap().values()) {
+ pCollectionIdsToConsumers.put(
+ pCollectionId,
+ runReadLoop);
+ }
+
+ return runner;
+ }
+ }
+
private final PipelineOptions pipelineOptions;
private final RunnerApi.FunctionSpec definition;
private final Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers;
- public BoundedSourceRunner(
+ BoundedSourceRunner(
PipelineOptions pipelineOptions,
RunnerApi.FunctionSpec definition,
- Map<String, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> outputMap) {
+ Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers) {
this.pipelineOptions = pipelineOptions;
this.definition = definition;
- this.consumers = ImmutableList.copyOf(FluentIterable.concat(outputMap.values()));
+ this.consumers = consumers;
}
/**
http://git-wip-us.apache.org/repos/asf/beam/blob/c9c1a05d/sdks/java/harness/src/main/java/org/apache/beam/runners/core/DoFnRunnerFactory.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/DoFnRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/DoFnRunnerFactory.java
new file mode 100644
index 0000000..3c0b6eb
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/DoFnRunnerFactory.java
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.core;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.auto.service.AutoService;
+import com.google.common.collect.Collections2;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableMultimap;
+import com.google.common.collect.Multimap;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.BytesValue;
+import com.google.protobuf.InvalidProtocolBufferException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Objects;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+import org.apache.beam.fn.harness.data.BeamFnDataClient;
+import org.apache.beam.fn.harness.fake.FakeStepContext;
+import org.apache.beam.fn.harness.fn.ThrowingConsumer;
+import org.apache.beam.fn.harness.fn.ThrowingRunnable;
+import org.apache.beam.runners.core.DoFnRunners.OutputManager;
+import org.apache.beam.runners.dataflow.util.DoFnInfo;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.util.SerializableUtils;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+
+/**
+ * Classes associated with converting {@link RunnerApi.PTransform}s to {@link DoFnRunner}s.
+ *
+ * <p>The nested {@link Registrar} publishes a {@link Factory} under the Java DoFn URN
+ * ({@code urn:org.apache.beam:dofn:java:0.1}); the {@link Factory} turns a PTransform carrying a
+ * serialized {@link DoFnInfo} into a {@link DoFnRunner} wired into the bundle's consumers.
+ *
+ * <p>TODO: Move DoFnRunners into SDK harness and merge the methods below into it, removing this
+ * class.
+ */
+public class DoFnRunnerFactory {
+
+ private static final String URN = "urn:org.apache.beam:dofn:java:0.1";
+
+ /**
+ * A registrar which provides a factory to handle Java {@link DoFn}s.
+ *
+ * <p>It is annotated with {@code @AutoService} for {@link PTransformRunnerFactory.Registrar}
+ * and returns a single-entry map from the Java DoFn URN to a new {@link Factory}.
+ */
+ @AutoService(PTransformRunnerFactory.Registrar.class)
+ public static class Registrar implements
+ PTransformRunnerFactory.Registrar {
+
+ @Override
+ public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() {
+ return ImmutableMap.of(URN, new Factory());
+ }
+ }
+
+ /**
+ * A factory for {@link DoFnRunner}s.
+ *
+ * <p>{@link #createRunnerForPTransform} reads a serialized {@link DoFnInfo} out of the
+ * transform's spec, checks that its output ids match the transform's output names, builds a
+ * runner via {@link DoFnRunners#simpleRunner}, and registers the runner's
+ * {@code startBundle}, {@code processElement} and {@code finishBundle} methods with the
+ * supplied start functions, input PCollection consumers and finish functions. The
+ * {@code beamFnDataClient}, {@code pTransformId}, {@code processBundleInstructionId},
+ * {@code pCollections} and {@code coders} arguments are not used by this factory.
+ *
+ * <p>An {@link IllegalArgumentException} is thrown if the spec's parameter cannot be unpacked
+ * as a {@link BytesValue} or if the output names and {@link DoFnInfo} output ids differ.
+ */
+ static class Factory<InputT, OutputT>
+ implements PTransformRunnerFactory<DoFnRunner<InputT, OutputT>> {
+
+ @Override
+ public DoFnRunner<InputT, OutputT> createRunnerForPTransform(
+ PipelineOptions pipelineOptions,
+ BeamFnDataClient beamFnDataClient,
+ String pTransformId,
+ RunnerApi.PTransform pTransform,
+ Supplier<String> processBundleInstructionId,
+ Map<String, RunnerApi.PCollection> pCollections,
+ Map<String, RunnerApi.Coder> coders,
+ Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
+ Consumer<ThrowingRunnable> addStartFunction,
+ Consumer<ThrowingRunnable> addFinishFunction) {
+
+ // For every output of this PTransform, map its local output name to the consumers registered
+ // for the PCollection id that the output produces.
+ ImmutableMap.Builder<String, Collection<ThrowingConsumer<WindowedValue<?>>>>
+ outputMapBuilder = ImmutableMap.builder();
+ for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
+ outputMapBuilder.put(
+ entry.getKey(),
+ pCollectionIdsToConsumers.get(entry.getValue()));
+ }
+ ImmutableMap<String, Collection<ThrowingConsumer<WindowedValue<?>>>> outputMap =
+ outputMapBuilder.build();
+
+ // Get the DoFnInfo from the serialized blob: the spec's parameter is unpacked as a BytesValue
+ // and its bytes are handed to SerializableUtils for deserialization.
+ ByteString serializedFn;
+ try {
+ serializedFn = pTransform.getSpec().getParameter().unpack(BytesValue.class).getValue();
+ } catch (InvalidProtocolBufferException e) {
+ throw new IllegalArgumentException(
+ String.format("Unable to unwrap DoFn %s", pTransform.getSpec()), e);
+ }
+ DoFnInfo<?, ?> doFnInfo =
+ (DoFnInfo<?, ?>)
+ SerializableUtils.deserializeFromByteArray(serializedFn.toByteArray(), "DoFnInfo");
+
+ // Verify that the DoFnInfo tag to output map matches the output map on the PTransform. The
+ // PTransform's output names are parsed as longs so they can be compared with DoFnInfo's ids.
+ checkArgument(
+ Objects.equals(
+ new HashSet<>(Collections2.transform(outputMap.keySet(), Long::parseLong)),
+ doFnInfo.getOutputMap().keySet()),
+ "Unexpected mismatch between transform output map %s and DoFnInfo output map %s.",
+ outputMap.keySet(),
+ doFnInfo.getOutputMap());
+
+ ImmutableMultimap.Builder<TupleTag<?>,
+ ThrowingConsumer<WindowedValue<OutputT>>> tagToOutput =
+ ImmutableMultimap.builder();
+ for (Map.Entry<Long, TupleTag<?>> entry : doFnInfo.getOutputMap().entrySet()) {
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers =
+ (Collection) outputMap.get(Long.toString(entry.getKey()));
+ tagToOutput.putAll(entry.getValue(), consumers);
+ }
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ Map<TupleTag<?>, Collection<ThrowingConsumer<WindowedValue<?>>>> tagBasedOutputMap =
+ (Map) tagToOutput.build().asMap();
+
+ OutputManager outputManager =
+ new OutputManager() {
+ Map<TupleTag<?>, Collection<ThrowingConsumer<WindowedValue<?>>>> tupleTagToOutput =
+ tagBasedOutputMap;
+
+ @Override
+ public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
+ try {
+ Collection<ThrowingConsumer<WindowedValue<?>>> consumers =
+ tupleTagToOutput.get(tag);
+ if (consumers == null) {
+ /* This is a normal case, e.g., if a DoFn has output but that output is not
+ * consumed. Drop the output. */
+ return;
+ }
+ for (ThrowingConsumer<WindowedValue<?>> consumer : consumers) {
+ consumer.accept(output);
+ }
+ } catch (Throwable t) { // Anything thrown by a consumer is rethrown unchecked.
+ throw new RuntimeException(t);
+ }
+ }
+ };
+
+ @SuppressWarnings({"unchecked", "rawtypes", "deprecation"})
+ DoFnRunner<InputT, OutputT> runner =
+ DoFnRunners.simpleRunner(
+ pipelineOptions,
+ (DoFn) doFnInfo.getDoFn(),
+ NullSideInputReader.empty(), /* TODO: pass a real side input reader; an empty one is used. */
+ outputManager,
+ (TupleTag) doFnInfo.getOutputMap().get(doFnInfo.getMainOutput()),
+ new ArrayList<>(doFnInfo.getOutputMap().values()),
+ new FakeStepContext(),
+ (WindowingStrategy) doFnInfo.getWindowingStrategy());
+
+ // Register the appropriate handlers: startBundle as a start function, processElement as a
+ // consumer of every input PCollection, and finishBundle as a finish function.
+ addStartFunction.accept(runner::startBundle);
+ for (String pcollectionId : pTransform.getInputsMap().values()) {
+ pCollectionIdsToConsumers.put(
+ pcollectionId,
+ (ThrowingConsumer) (ThrowingConsumer<WindowedValue<InputT>>) runner::processElement);
+ }
+ addFinishFunction.accept(runner::finishBundle);
+ return runner;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/c9c1a05d/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java
new file mode 100644
index 0000000..b325db4
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.core;
+
+import com.google.common.collect.Multimap;
+import java.io.IOException;
+import java.util.Map;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+import org.apache.beam.fn.harness.data.BeamFnDataClient;
+import org.apache.beam.fn.harness.fn.ThrowingConsumer;
+import org.apache.beam.fn.harness.fn.ThrowingRunnable;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.util.WindowedValue;
+
+/**
+ * A factory able to instantiate an appropriate handler for a given PTransform.
+ *
+ * @param <T> The type of handler created by this factory.
+ */
+public interface PTransformRunnerFactory<T> {
+
+ /**
+ * Creates and returns a handler for a given PTransform. Note that the handler must support
+ * processing multiple bundles. The handler will be discarded if an error is thrown during
+ * element processing, or during execution of start/finish.
+ *
+ * @param pipelineOptions Pipeline options
+ * @param beamFnDataClient The {@link BeamFnDataClient} made available to the handler.
+ * @param pTransformId The id of the PTransform.
+ * @param pTransform The PTransform definition.
+ * @param processBundleInstructionId A supplier containing the active process bundle instruction
+ * id.
+ * @param pCollections A mapping from PCollection id to PCollection definition.
+ * @param coders A mapping from coder id to coder definition.
+ * @param pCollectionIdsToConsumers A mapping from PCollection id to a collection of consumers.
+ * Note that if this handler is a consumer, it should register itself within this multimap under
+ * the appropriate PCollection ids. Also note that all output consumers needed by this PTransform
+ * (based on the values of the {@link RunnerApi.PTransform#getOutputsMap()}) will have already
+ * registered within this multimap.
+ * @param addStartFunction A consumer to register a start bundle handler with.
+ * @param addFinishFunction A consumer to register a finish bundle handler with.
+ * @return The handler created for the PTransform.
+ * @throws IOException If creating the handler fails with an I/O error.
+ */
+ T createRunnerForPTransform(
+ PipelineOptions pipelineOptions,
+ BeamFnDataClient beamFnDataClient,
+ String pTransformId,
+ RunnerApi.PTransform pTransform,
+ Supplier<String> processBundleInstructionId,
+ Map<String, RunnerApi.PCollection> pCollections,
+ Map<String, RunnerApi.Coder> coders,
+ Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
+ Consumer<ThrowingRunnable> addStartFunction,
+ Consumer<ThrowingRunnable> addFinishFunction) throws IOException;
+
+ /**
+ * A registrar which can return a mapping from {@link RunnerApi.FunctionSpec#getUrn()} to
+ * a factory capable of instantiating an appropriate handler.
+ */
+ interface Registrar {
+ /**
+ * Returns a mapping from {@link RunnerApi.FunctionSpec#getUrn()} to a factory capable of
+ * instantiating an appropriate handler.
+ *
+ * @return A map keyed by URN whose values are the factories handling that URN.
+ */
+ Map<String, PTransformRunnerFactory> getPTransformRunnerFactories();
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/c9c1a05d/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
index 562f91f..a616b2c 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
@@ -18,62 +18,28 @@
package org.apache.beam.fn.harness.control;
-import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow;
-import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
import static org.hamcrest.Matchers.contains;
-import static org.hamcrest.Matchers.containsInAnyOrder;
-import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
-import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
-import static org.junit.Assert.assertTrue;
-import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.eq;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoMoreInteractions;
-import static org.mockito.Mockito.verifyZeroInteractions;
-import static org.mockito.Mockito.when;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.google.common.base.Suppliers;
-import com.google.common.collect.HashMultimap;
-import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.Iterables;
import com.google.common.collect.Multimap;
-import com.google.protobuf.Any;
-import com.google.protobuf.ByteString;
-import com.google.protobuf.BytesValue;
import com.google.protobuf.Message;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
-import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer;
import org.apache.beam.fn.harness.fn.ThrowingConsumer;
import org.apache.beam.fn.harness.fn.ThrowingRunnable;
import org.apache.beam.fn.v1.BeamFnApi;
-import org.apache.beam.runners.dataflow.util.CloudObjects;
-import org.apache.beam.runners.dataflow.util.DoFnInfo;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.StringUtf8Coder;
-import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.runners.core.PTransformRunnerFactory;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
-import org.apache.beam.sdk.io.CountingSource;
+import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
-import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
-import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.KV;
-import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@@ -82,55 +48,14 @@ import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
-import org.mockito.Matchers;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
/** Tests for {@link ProcessBundleHandler}. */
@RunWith(JUnit4.class)
public class ProcessBundleHandlerTest {
- private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
-
- private static final Coder<WindowedValue<String>> STRING_CODER =
- WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE);
- private static final String LONG_CODER_SPEC_ID = "998L";
- private static final String STRING_CODER_SPEC_ID = "999L";
- private static final BeamFnApi.RemoteGrpcPort REMOTE_PORT = BeamFnApi.RemoteGrpcPort.newBuilder()
- .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.newBuilder()
- .setId("58L")
- .setUrl("TestUrl"))
- .build();
- private static final RunnerApi.Coder LONG_CODER_SPEC;
- private static final RunnerApi.Coder STRING_CODER_SPEC;
- static {
- try {
- STRING_CODER_SPEC = RunnerApi.Coder.newBuilder()
- .setSpec(RunnerApi.SdkFunctionSpec.newBuilder()
- .setSpec(RunnerApi.FunctionSpec.newBuilder()
- .setParameter(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom(
- OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(STRING_CODER))))
- .build())))
- .build())
- .build();
- LONG_CODER_SPEC = RunnerApi.Coder.newBuilder()
- .setSpec(RunnerApi.SdkFunctionSpec.newBuilder()
- .setSpec(RunnerApi.FunctionSpec.newBuilder()
- .setParameter(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom(
- OBJECT_MAPPER.writeValueAsBytes(
- CloudObjects.asCloudObject(WindowedValue.getFullCoder(VarLongCoder.of(),
- GlobalWindow.Coder.INSTANCE)))))
- .build())))
- .build())
- .build();
- } catch (IOException e) {
- throw new ExceptionInInitializerError(e);
- }
- }
-
private static final String DATA_INPUT_URN = "urn:org.apache.beam:source:runner:0.1";
private static final String DATA_OUTPUT_URN = "urn:org.apache.beam:sink:runner:0.1";
- private static final String JAVA_DO_FN_URN = "urn:org.apache.beam:dofn:java:0.1";
- private static final String JAVA_SOURCE_URN = "urn:org.apache.beam:source:java:0.1";
@Rule public ExpectedException thrown = ExpectedException.none();
@@ -161,16 +86,16 @@ public class ProcessBundleHandlerTest {
List<RunnerApi.PTransform> transformsProcessed = new ArrayList<>();
List<String> orderOfOperations = new ArrayList<>();
- ProcessBundleHandler handler = new ProcessBundleHandler(
- PipelineOptionsFactory.create(),
- fnApiRegistry::get,
- beamFnDataClient) {
+ PTransformRunnerFactory<Object> startFinishRecorder = new PTransformRunnerFactory<Object>() {
@Override
- protected void createRunnerForPTransform(
+ public Object createRunnerForPTransform(
+ PipelineOptions pipelineOptions,
+ BeamFnDataClient beamFnDataClient,
String pTransformId,
RunnerApi.PTransform pTransform,
Supplier<String> processBundleInstructionId,
Map<String, RunnerApi.PCollection> pCollections,
+ Map<String, RunnerApi.Coder> coders,
Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
Consumer<ThrowingRunnable> addStartFunction,
Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
@@ -182,8 +107,18 @@ public class ProcessBundleHandlerTest {
() -> orderOfOperations.add("Start" + pTransformId));
addFinishFunction.accept(
() -> orderOfOperations.add("Finish" + pTransformId));
+ return null;
}
};
+
+ ProcessBundleHandler handler = new ProcessBundleHandler(
+ PipelineOptionsFactory.create(),
+ fnApiRegistry::get,
+ beamFnDataClient,
+ ImmutableMap.of(
+ DATA_INPUT_URN, startFinishRecorder,
+ DATA_OUTPUT_URN, startFinishRecorder));
+
handler.processBundle(BeamFnApi.InstructionRequest.newBuilder()
.setInstructionId("999L")
.setProcessBundle(
@@ -211,21 +146,25 @@ public class ProcessBundleHandlerTest {
ProcessBundleHandler handler = new ProcessBundleHandler(
PipelineOptionsFactory.create(),
fnApiRegistry::get,
- beamFnDataClient) {
- @Override
- protected void createRunnerForPTransform(
- String pTransformId,
- RunnerApi.PTransform pTransform,
- Supplier<String> processBundleInstructionId,
- Map<String, RunnerApi.PCollection> pCollections,
- Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
- Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
- thrown.expect(IllegalStateException.class);
- thrown.expectMessage("TestException");
- throw new IllegalStateException("TestException");
- }
- };
+ beamFnDataClient,
+ ImmutableMap.of(DATA_INPUT_URN, new PTransformRunnerFactory<Object>() {
+ @Override
+ public Object createRunnerForPTransform(
+ PipelineOptions pipelineOptions,
+ BeamFnDataClient beamFnDataClient,
+ String pTransformId,
+ RunnerApi.PTransform pTransform,
+ Supplier<String> processBundleInstructionId,
+ Map<String, RunnerApi.PCollection> pCollections,
+ Map<String, RunnerApi.Coder> coders,
+ Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
+ Consumer<ThrowingRunnable> addStartFunction,
+ Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+ thrown.expect(IllegalStateException.class);
+ thrown.expectMessage("TestException");
+ throw new IllegalStateException("TestException");
+ }
+ }));
handler.processBundle(
BeamFnApi.InstructionRequest.newBuilder().setProcessBundle(
BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference("1L"))
@@ -245,25 +184,26 @@ public class ProcessBundleHandlerTest {
ProcessBundleHandler handler = new ProcessBundleHandler(
PipelineOptionsFactory.create(),
fnApiRegistry::get,
- beamFnDataClient) {
- @Override
- protected void createRunnerForPTransform(
- String pTransformId,
- RunnerApi.PTransform pTransform,
- Supplier<String> processBundleInstructionId,
- Map<String, RunnerApi.PCollection> pCollections,
- Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
- Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
- thrown.expect(IllegalStateException.class);
- thrown.expectMessage("TestException");
- addStartFunction.accept(this::throwException);
- }
-
- private void throwException() {
- throw new IllegalStateException("TestException");
- }
- };
+ beamFnDataClient,
+ ImmutableMap.of(DATA_INPUT_URN, new PTransformRunnerFactory<Object>() {
+ @Override
+ public Object createRunnerForPTransform(
+ PipelineOptions pipelineOptions,
+ BeamFnDataClient beamFnDataClient,
+ String pTransformId,
+ RunnerApi.PTransform pTransform,
+ Supplier<String> processBundleInstructionId,
+ Map<String, RunnerApi.PCollection> pCollections,
+ Map<String, RunnerApi.Coder> coders,
+ Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
+ Consumer<ThrowingRunnable> addStartFunction,
+ Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+ thrown.expect(IllegalStateException.class);
+ thrown.expectMessage("TestException");
+ addStartFunction.accept(ProcessBundleHandlerTest::throwException);
+ return null;
+ }
+ }));
handler.processBundle(
BeamFnApi.InstructionRequest.newBuilder().setProcessBundle(
BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference("1L"))
@@ -283,338 +223,33 @@ public class ProcessBundleHandlerTest {
ProcessBundleHandler handler = new ProcessBundleHandler(
PipelineOptionsFactory.create(),
fnApiRegistry::get,
- beamFnDataClient) {
- @Override
- protected void createRunnerForPTransform(
- String pTransformId,
- RunnerApi.PTransform pTransform,
- Supplier<String> processBundleInstructionId,
- Map<String, RunnerApi.PCollection> pCollections,
- Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
- Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
- thrown.expect(IllegalStateException.class);
- thrown.expectMessage("TestException");
- addFinishFunction.accept(this::throwException);
- }
-
- private void throwException() {
- throw new IllegalStateException("TestException");
- }
- };
+ beamFnDataClient,
+ ImmutableMap.of(DATA_INPUT_URN, new PTransformRunnerFactory<Object>() {
+ @Override
+ public Object createRunnerForPTransform(
+ PipelineOptions pipelineOptions,
+ BeamFnDataClient beamFnDataClient,
+ String pTransformId,
+ RunnerApi.PTransform pTransform,
+ Supplier<String> processBundleInstructionId,
+ Map<String, RunnerApi.PCollection> pCollections,
+ Map<String, RunnerApi.Coder> coders,
+ Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
+ Consumer<ThrowingRunnable> addStartFunction,
+ Consumer<ThrowingRunnable> addFinishFunction) throws IOException {
+ thrown.expect(IllegalStateException.class);
+ thrown.expectMessage("TestException");
+ addFinishFunction.accept(ProcessBundleHandlerTest::throwException);
+ return null;
+ }
+ }));
handler.processBundle(
BeamFnApi.InstructionRequest.newBuilder().setProcessBundle(
BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference("1L"))
.build());
}
- private static class TestDoFn extends DoFn<String, String> {
- private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput");
- private static final TupleTag<String> additionalOutput = new TupleTag<>("output");
-
- private BoundedWindow window;
-
- @ProcessElement
- public void processElement(ProcessContext context, BoundedWindow window) {
- context.output("MainOutput" + context.element());
- context.output(additionalOutput, "AdditionalOutput" + context.element());
- this.window = window;
- }
-
- @FinishBundle
- public void finishBundle(FinishBundleContext context) {
- if (window != null) {
- context.output("FinishBundle", window.maxTimestamp(), window);
- window = null;
- }
- }
- }
-
- /**
- * Create a DoFn that has 3 inputs (inputATarget1, inputATarget2, inputBTarget) and 2 outputs
- * (mainOutput, output). Validate that inputs are fed to the {@link DoFn} and that outputs
- * are directed to the correct consumers.
- */
- @Test
- public void testCreatingAndProcessingDoFn() throws Exception {
- Map<String, Message> fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC);
- String pTransformId = "100L";
- String mainOutputId = "101";
- String additionalOutputId = "102";
-
- DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn(
- new TestDoFn(),
- WindowingStrategy.globalDefault(),
- ImmutableList.of(),
- StringUtf8Coder.of(),
- Long.parseLong(mainOutputId),
- ImmutableMap.of(
- Long.parseLong(mainOutputId), TestDoFn.mainOutput,
- Long.parseLong(additionalOutputId), TestDoFn.additionalOutput));
- RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder()
- .setUrn(JAVA_DO_FN_URN)
- .setParameter(Any.pack(BytesValue.newBuilder()
- .setValue(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo)))
- .build()))
- .build();
- RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
- .setSpec(functionSpec)
- .putInputs("inputA", "inputATarget")
- .putInputs("inputB", "inputBTarget")
- .putOutputs(mainOutputId, "mainOutputTarget")
- .putOutputs(additionalOutputId, "additionalOutputTarget")
- .build();
-
- List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
- List<WindowedValue<String>> additionalOutputValues = new ArrayList<>();
- Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create();
- consumers.put("mainOutputTarget",
- (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) mainOutputValues::add);
- consumers.put("additionalOutputTarget",
- (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) additionalOutputValues::add);
- List<ThrowingRunnable> startFunctions = new ArrayList<>();
- List<ThrowingRunnable> finishFunctions = new ArrayList<>();
-
- ProcessBundleHandler handler = new ProcessBundleHandler(
- PipelineOptionsFactory.create(),
- fnApiRegistry::get,
- beamFnDataClient);
- handler.createRunnerForPTransform(
- pTransformId,
- pTransform,
- Suppliers.ofInstance("57L")::get,
- ImmutableMap.of(),
- consumers,
- startFunctions::add,
- finishFunctions::add);
-
- Iterables.getOnlyElement(startFunctions).run();
- mainOutputValues.clear();
-
- assertThat(consumers.keySet(), containsInAnyOrder(
- "inputATarget", "inputBTarget", "mainOutputTarget", "additionalOutputTarget"));
-
- Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A1"));
- Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A2"));
- Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("B"));
- assertThat(mainOutputValues, contains(
- valueInGlobalWindow("MainOutputA1"),
- valueInGlobalWindow("MainOutputA2"),
- valueInGlobalWindow("MainOutputB")));
- assertThat(additionalOutputValues, contains(
- valueInGlobalWindow("AdditionalOutputA1"),
- valueInGlobalWindow("AdditionalOutputA2"),
- valueInGlobalWindow("AdditionalOutputB")));
- mainOutputValues.clear();
- additionalOutputValues.clear();
-
- Iterables.getOnlyElement(finishFunctions).run();
- assertThat(
- mainOutputValues,
- contains(
- timestampedValueInGlobalWindow("FinishBundle", GlobalWindow.INSTANCE.maxTimestamp())));
- mainOutputValues.clear();
- }
-
- @Test
- public void testCreatingAndProcessingSource() throws Exception {
- Map<String, Message> fnApiRegistry = ImmutableMap.of(LONG_CODER_SPEC_ID, LONG_CODER_SPEC);
- List<WindowedValue<String>> outputValues = new ArrayList<>();
-
- Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create();
- consumers.put("outputPC",
- (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) outputValues::add);
- List<ThrowingRunnable> startFunctions = new ArrayList<>();
- List<ThrowingRunnable> finishFunctions = new ArrayList<>();
-
- RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder()
- .setUrn(JAVA_SOURCE_URN)
- .setParameter(Any.pack(BytesValue.newBuilder()
- .setValue(ByteString.copyFrom(
- SerializableUtils.serializeToByteArray(CountingSource.upTo(3))))
- .build()))
- .build();
-
- RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
- .setSpec(functionSpec)
- .putInputs("input", "inputPC")
- .putOutputs("output", "outputPC")
- .build();
-
- ProcessBundleHandler handler = new ProcessBundleHandler(
- PipelineOptionsFactory.create(),
- fnApiRegistry::get,
- beamFnDataClient);
-
- handler.createRunnerForPTransform(
- "pTransformId",
- pTransform,
- Suppliers.ofInstance("57L")::get,
- ImmutableMap.of(),
- consumers,
- startFunctions::add,
- finishFunctions::add);
-
- // This is testing a deprecated way of running sources and should be removed
- // once all source definitions are instead propagated along the input edge.
- Iterables.getOnlyElement(startFunctions).run();
- assertThat(outputValues, contains(
- valueInGlobalWindow(0L),
- valueInGlobalWindow(1L),
- valueInGlobalWindow(2L)));
- outputValues.clear();
-
- // Check that when passing a source along as an input, the source is processed.
- assertThat(consumers.keySet(), containsInAnyOrder("inputPC", "outputPC"));
- Iterables.getOnlyElement(consumers.get("inputPC")).accept(
- valueInGlobalWindow(CountingSource.upTo(2)));
- assertThat(outputValues, contains(
- valueInGlobalWindow(0L),
- valueInGlobalWindow(1L)));
-
- assertThat(finishFunctions, empty());
- }
-
- @Test
- public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception {
- Map<String, Message> fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC);
- String bundleId = "57";
- String outputId = "101";
-
- List<WindowedValue<String>> outputValues = new ArrayList<>();
-
- Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create();
- consumers.put("outputPC",
- (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) outputValues::add);
- List<ThrowingRunnable> startFunctions = new ArrayList<>();
- List<ThrowingRunnable> finishFunctions = new ArrayList<>();
-
- RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder()
- .setUrn(DATA_INPUT_URN)
- .setParameter(Any.pack(REMOTE_PORT))
- .build();
-
- RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
- .setSpec(functionSpec)
- .putOutputs(outputId, "outputPC")
- .build();
-
- ProcessBundleHandler handler = new ProcessBundleHandler(
- PipelineOptionsFactory.create(),
- fnApiRegistry::get,
- beamFnDataClient);
-
- handler.createRunnerForPTransform(
- "pTransformId",
- pTransform,
- Suppliers.ofInstance(bundleId)::get,
- ImmutableMap.of("outputPC",
- RunnerApi.PCollection.newBuilder().setCoderId(STRING_CODER_SPEC_ID).build()),
- consumers,
- startFunctions::add,
- finishFunctions::add);
-
- verifyZeroInteractions(beamFnDataClient);
-
- CompletableFuture<Void> completionFuture = new CompletableFuture<>();
- when(beamFnDataClient.forInboundConsumer(any(), any(), any(), any()))
- .thenReturn(completionFuture);
- Iterables.getOnlyElement(startFunctions).run();
- verify(beamFnDataClient).forInboundConsumer(
- eq(REMOTE_PORT.getApiServiceDescriptor()),
- eq(KV.of(bundleId, BeamFnApi.Target.newBuilder()
- .setPrimitiveTransformReference("pTransformId")
- .setName(outputId)
- .build())),
- eq(STRING_CODER),
- consumerCaptor.capture());
-
- consumerCaptor.getValue().accept(valueInGlobalWindow("TestValue"));
- assertThat(outputValues, contains(valueInGlobalWindow("TestValue")));
- outputValues.clear();
-
- assertThat(consumers.keySet(), containsInAnyOrder("outputPC"));
-
- completionFuture.complete(null);
- Iterables.getOnlyElement(finishFunctions).run();
-
- verifyNoMoreInteractions(beamFnDataClient);
- }
-
- @Test
- public void testCreatingAndProcessingBeamFnDataWriteRunner() throws Exception {
- Map<String, Message> fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC);
- String bundleId = "57L";
- String inputId = "100L";
-
- Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create();
- List<ThrowingRunnable> startFunctions = new ArrayList<>();
- List<ThrowingRunnable> finishFunctions = new ArrayList<>();
-
- RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder()
- .setUrn(DATA_OUTPUT_URN)
- .setParameter(Any.pack(REMOTE_PORT))
- .build();
-
- RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
- .setSpec(functionSpec)
- .putInputs(inputId, "inputPC")
- .build();
-
- ProcessBundleHandler handler = new ProcessBundleHandler(
- PipelineOptionsFactory.create(),
- fnApiRegistry::get,
- beamFnDataClient);
-
- handler.createRunnerForPTransform(
- "ptransformId",
- pTransform,
- Suppliers.ofInstance(bundleId)::get,
- ImmutableMap.of("inputPC",
- RunnerApi.PCollection.newBuilder().setCoderId(STRING_CODER_SPEC_ID).build()),
- consumers,
- startFunctions::add,
- finishFunctions::add);
-
- verifyZeroInteractions(beamFnDataClient);
-
- List<WindowedValue<String>> outputValues = new ArrayList<>();
- AtomicBoolean wasCloseCalled = new AtomicBoolean();
- CloseableThrowingConsumer<WindowedValue<String>> outputConsumer =
- new CloseableThrowingConsumer<WindowedValue<String>>(){
- @Override
- public void close() throws Exception {
- wasCloseCalled.set(true);
- }
-
- @Override
- public void accept(WindowedValue<String> t) throws Exception {
- outputValues.add(t);
- }
- };
-
- when(beamFnDataClient.forOutboundConsumer(
- any(),
- any(),
- Matchers.<Coder<WindowedValue<String>>>any())).thenReturn(outputConsumer);
- Iterables.getOnlyElement(startFunctions).run();
- verify(beamFnDataClient).forOutboundConsumer(
- eq(REMOTE_PORT.getApiServiceDescriptor()),
- eq(KV.of(bundleId, BeamFnApi.Target.newBuilder()
- .setPrimitiveTransformReference("ptransformId")
- .setName(inputId)
- .build())),
- eq(STRING_CODER));
-
- assertThat(consumers.keySet(), containsInAnyOrder("inputPC"));
- Iterables.getOnlyElement(consumers.get("inputPC")).accept(valueInGlobalWindow("TestValue"));
- assertThat(outputValues, contains(valueInGlobalWindow("TestValue")));
- outputValues.clear();
-
- assertFalse(wasCloseCalled.get());
- Iterables.getOnlyElement(finishFunctions).run();
- assertTrue(wasCloseCalled.get());
-
- verifyNoMoreInteractions(beamFnDataClient);
+ private static void throwException() {
+ throw new IllegalStateException("TestException");
}
}
http://git-wip-us.apache.org/repos/asf/beam/blob/c9c1a05d/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java
index 7e8ab1a..d6a476e 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java
@@ -20,41 +20,51 @@ package org.apache.beam.runners.core;
import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.junit.Assert.assertThat;
+import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.base.Suppliers;
+import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Multimap;
import com.google.common.util.concurrent.Uninterruptibles;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.BytesValue;
import java.io.IOException;
import java.util.ArrayList;
-import java.util.Collection;
import java.util.List;
-import java.util.Map;
+import java.util.ServiceLoader;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.fn.ThrowingConsumer;
+import org.apache.beam.fn.harness.fn.ThrowingRunnable;
import org.apache.beam.fn.harness.test.TestExecutors;
import org.apache.beam.fn.harness.test.TestExecutors.TestExecutorService;
import org.apache.beam.fn.v1.BeamFnApi;
+import org.apache.beam.runners.core.PTransformRunnerFactory.Registrar;
import org.apache.beam.runners.dataflow.util.CloudObjects;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
+import org.hamcrest.collection.IsMapContaining;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@@ -68,15 +78,18 @@ import org.mockito.MockitoAnnotations;
/** Tests for {@link BeamFnDataReadRunner}. */
@RunWith(JUnit4.class)
public class BeamFnDataReadRunnerTest {
- private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+ private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private static final BeamFnApi.RemoteGrpcPort PORT_SPEC = BeamFnApi.RemoteGrpcPort.newBuilder()
.setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance()).build();
private static final RunnerApi.FunctionSpec FUNCTION_SPEC = RunnerApi.FunctionSpec.newBuilder()
.setParameter(Any.pack(PORT_SPEC)).build();
private static final Coder<WindowedValue<String>> CODER =
WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE);
+ private static final String CODER_SPEC_ID = "string-coder-id";
private static final RunnerApi.Coder CODER_SPEC;
+ private static final String URN = "urn:org.apache.beam:source:runner:0.1";
+
static {
try {
CODER_SPEC = RunnerApi.Coder.newBuilder().setSpec(
@@ -98,7 +111,7 @@ public class BeamFnDataReadRunnerTest {
.build();
@Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool);
- @Mock private BeamFnDataClient mockBeamFnDataClientFactory;
+ @Mock private BeamFnDataClient mockBeamFnDataClient;
@Captor private ArgumentCaptor<ThrowingConsumer<WindowedValue<String>>> consumerCaptor;
@Before
@@ -107,32 +120,93 @@ public class BeamFnDataReadRunnerTest {
}
@Test
+ public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception {
+ String bundleId = "57";
+ String outputId = "101";
+
+ List<WindowedValue<String>> outputValues = new ArrayList<>();
+
+ Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create();
+ consumers.put("outputPC",
+ (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) outputValues::add);
+ List<ThrowingRunnable> startFunctions = new ArrayList<>();
+ List<ThrowingRunnable> finishFunctions = new ArrayList<>();
+
+ RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder()
+ .setUrn("urn:org.apache.beam:source:runner:0.1")
+ .setParameter(Any.pack(PORT_SPEC))
+ .build();
+
+ RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
+ .setSpec(functionSpec)
+ .putOutputs(outputId, "outputPC")
+ .build();
+
+ new BeamFnDataReadRunner.Factory<String>().createRunnerForPTransform(
+ PipelineOptionsFactory.create(),
+ mockBeamFnDataClient,
+ "pTransformId",
+ pTransform,
+ Suppliers.ofInstance(bundleId)::get,
+ ImmutableMap.of("outputPC",
+ RunnerApi.PCollection.newBuilder().setCoderId(CODER_SPEC_ID).build()),
+ ImmutableMap.of(CODER_SPEC_ID, CODER_SPEC),
+ consumers,
+ startFunctions::add,
+ finishFunctions::add);
+
+ verifyZeroInteractions(mockBeamFnDataClient);
+
+ CompletableFuture<Void> completionFuture = new CompletableFuture<>();
+ when(mockBeamFnDataClient.forInboundConsumer(any(), any(), any(), any()))
+ .thenReturn(completionFuture);
+ Iterables.getOnlyElement(startFunctions).run();
+ verify(mockBeamFnDataClient).forInboundConsumer(
+ eq(PORT_SPEC.getApiServiceDescriptor()),
+ eq(KV.of(bundleId, BeamFnApi.Target.newBuilder()
+ .setPrimitiveTransformReference("pTransformId")
+ .setName(outputId)
+ .build())),
+ eq(CODER),
+ consumerCaptor.capture());
+
+ consumerCaptor.getValue().accept(valueInGlobalWindow("TestValue"));
+ assertThat(outputValues, contains(valueInGlobalWindow("TestValue")));
+ outputValues.clear();
+
+ assertThat(consumers.keySet(), containsInAnyOrder("outputPC"));
+
+ completionFuture.complete(null);
+ Iterables.getOnlyElement(finishFunctions).run();
+
+ verifyNoMoreInteractions(mockBeamFnDataClient);
+ }
+
+ @Test
public void testReuseForMultipleBundles() throws Exception {
CompletableFuture<Void> bundle1Future = new CompletableFuture<>();
CompletableFuture<Void> bundle2Future = new CompletableFuture<>();
- when(mockBeamFnDataClientFactory.forInboundConsumer(
+ when(mockBeamFnDataClient.forInboundConsumer(
any(),
any(),
any(),
any())).thenReturn(bundle1Future).thenReturn(bundle2Future);
List<WindowedValue<String>> valuesA = new ArrayList<>();
List<WindowedValue<String>> valuesB = new ArrayList<>();
- Map<String, Collection<ThrowingConsumer<WindowedValue<String>>>> outputMap = ImmutableMap.of(
- "outA", ImmutableList.of(valuesA::add),
- "outB", ImmutableList.of(valuesB::add));
+
AtomicReference<String> bundleId = new AtomicReference<>("0");
BeamFnDataReadRunner<String> readRunner = new BeamFnDataReadRunner<>(
FUNCTION_SPEC,
bundleId::get,
INPUT_TARGET,
CODER_SPEC,
- mockBeamFnDataClientFactory,
- outputMap);
+ mockBeamFnDataClient,
+ ImmutableList.of(valuesA::add, valuesB::add));
// Process for bundle id 0
readRunner.registerInputLocation();
- verify(mockBeamFnDataClientFactory).forInboundConsumer(
+ verify(mockBeamFnDataClient).forInboundConsumer(
eq(PORT_SPEC.getApiServiceDescriptor()),
eq(KV.of(bundleId.get(), INPUT_TARGET)),
eq(CODER),
@@ -164,7 +238,7 @@ public class BeamFnDataReadRunnerTest {
valuesB.clear();
readRunner.registerInputLocation();
- verify(mockBeamFnDataClientFactory).forInboundConsumer(
+ verify(mockBeamFnDataClient).forInboundConsumer(
eq(PORT_SPEC.getApiServiceDescriptor()),
eq(KV.of(bundleId.get(), INPUT_TARGET)),
eq(CODER),
@@ -190,6 +264,18 @@ public class BeamFnDataReadRunnerTest {
assertThat(valuesA, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL")));
assertThat(valuesB, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL")));
- verifyNoMoreInteractions(mockBeamFnDataClientFactory);
+ verifyNoMoreInteractions(mockBeamFnDataClient);
+ }
+
+ @Test
+ public void testRegistration() {
+ for (Registrar registrar :
+ ServiceLoader.load(Registrar.class)) {
+ if (registrar instanceof BeamFnDataReadRunner.Registrar) {
+ assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN));
+ return;
+ }
+ }
+ fail("Expected registrar not found.");
}
}