You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ke...@apache.org on 2017/07/18 20:05:35 UTC
[2/4] beam git commit: Fix split package in SDK harness
http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java
deleted file mode 100644
index b3cf3a7..0000000
--- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java
+++ /dev/null
@@ -1,547 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.core;
-
-import static com.google.common.base.Preconditions.checkArgument;
-
-import com.google.auto.service.AutoService;
-import com.google.common.collect.Collections2;
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableMultimap;
-import com.google.common.collect.Multimap;
-import com.google.protobuf.ByteString;
-import com.google.protobuf.BytesValue;
-import com.google.protobuf.InvalidProtocolBufferException;
-import java.util.Collection;
-import java.util.HashSet;
-import java.util.Iterator;
-import java.util.Map;
-import java.util.Objects;
-import java.util.function.Consumer;
-import java.util.function.Supplier;
-import org.apache.beam.fn.harness.data.BeamFnDataClient;
-import org.apache.beam.fn.harness.fn.ThrowingConsumer;
-import org.apache.beam.fn.harness.fn.ThrowingRunnable;
-import org.apache.beam.runners.core.construction.ParDoTranslation;
-import org.apache.beam.runners.dataflow.util.DoFnInfo;
-import org.apache.beam.sdk.common.runner.v1.RunnerApi;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.state.State;
-import org.apache.beam.sdk.state.TimeDomain;
-import org.apache.beam.sdk.state.Timer;
-import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.DoFn.OnTimerContext;
-import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
-import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
-import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
-import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.transforms.windowing.PaneInfo;
-import org.apache.beam.sdk.util.SerializableUtils;
-import org.apache.beam.sdk.util.UserCodeException;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
-import org.joda.time.Instant;
-
-/**
- * A {@link DoFnRunner} specific to integrating with the Fn Api. This is to remove the layers
- * of abstraction caused by StateInternals/TimerInternals since they model state and timer
- * concepts differently.
- */
-public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, OutputT> {
- /**
- * A registrar which provides a factory to handle Java {@link DoFn}s.
- */
- @AutoService(PTransformRunnerFactory.Registrar.class)
- public static class Registrar implements
- PTransformRunnerFactory.Registrar {
-
- @Override
- public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() {
- return ImmutableMap.of(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN, new Factory());
- }
- }
-
- /**
- * A factory for {@link FnApiDoFnRunner}.
- *
- * <p>The runner is built from a serialized {@code DoFnInfo} carried in the PTransform's spec
- * parameter and is wired to the consumers already registered for its output PCollections.
- */
- static class Factory<InputT, OutputT>
- implements PTransformRunnerFactory<DoFnRunner<InputT, OutputT>> {
-
- /**
- * Creates a {@link FnApiDoFnRunner} for {@code pTransform} and registers its start-bundle,
- * process-element and finish-bundle handlers.
- *
- * <p>The {@code beamFnDataClient}, {@code pTransformId}, {@code processBundleInstructionId},
- * {@code pCollections} and {@code coders} arguments are not read by this implementation.
- *
- * @throws IllegalArgumentException if the spec parameter cannot be unpacked as a
- * {@link BytesValue}, or if the PTransform's output names do not match the keys of the
- * {@code DoFnInfo} output map
- * @throws NumberFormatException if a PTransform output name is not a parseable long
- */
- @Override
- public DoFnRunner<InputT, OutputT> createRunnerForPTransform(
- PipelineOptions pipelineOptions,
- BeamFnDataClient beamFnDataClient,
- String pTransformId,
- RunnerApi.PTransform pTransform,
- Supplier<String> processBundleInstructionId,
- Map<String, RunnerApi.PCollection> pCollections,
- Map<String, RunnerApi.Coder> coders,
- Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
- Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) {
-
- // For every output of the PTransform, map the output name to the consumers registered for
- // the PCollection that the output produces.
- ImmutableMap.Builder<String, Collection<ThrowingConsumer<WindowedValue<?>>>>
- outputMapBuilder = ImmutableMap.builder();
- for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
- outputMapBuilder.put(
- entry.getKey(),
- pCollectionIdsToConsumers.get(entry.getValue()));
- }
- ImmutableMap<String, Collection<ThrowingConsumer<WindowedValue<?>>>> outputMap =
- outputMapBuilder.build();
-
- // Get the DoFnInfo from the serialized blob.
- ByteString serializedFn;
- try {
- serializedFn = pTransform.getSpec().getParameter().unpack(BytesValue.class).getValue();
- } catch (InvalidProtocolBufferException e) {
- throw new IllegalArgumentException(
- String.format("Unable to unwrap DoFn %s", pTransform.getSpec()), e);
- }
- @SuppressWarnings({"unchecked", "rawtypes"})
- DoFnInfo<InputT, OutputT> doFnInfo = (DoFnInfo) SerializableUtils.deserializeFromByteArray(
- serializedFn.toByteArray(), "DoFnInfo");
-
- // Verify that the DoFnInfo tag to output map matches the output map on the PTransform.
- // PTransform output names are strings while DoFnInfo keys its outputs by Long, so the names
- // are parsed before the two key sets are compared.
- checkArgument(
- Objects.equals(
- new HashSet<>(Collections2.transform(outputMap.keySet(), Long::parseLong)),
- doFnInfo.getOutputMap().keySet()),
- "Unexpected mismatch between transform output map %s and DoFnInfo output map %s.",
- outputMap.keySet(),
- doFnInfo.getOutputMap());
-
- // Re-key the consumers by the TupleTag that the DoFn uses to address each output.
- ImmutableMultimap.Builder<TupleTag<?>,
- ThrowingConsumer<WindowedValue<?>>> tagToOutputMapBuilder =
- ImmutableMultimap.builder();
- for (Map.Entry<Long, TupleTag<?>> entry : doFnInfo.getOutputMap().entrySet()) {
- @SuppressWarnings({"unchecked", "rawtypes"})
- Collection<ThrowingConsumer<WindowedValue<?>>> consumers =
- outputMap.get(Long.toString(entry.getKey()));
- tagToOutputMapBuilder.putAll(entry.getValue(), consumers);
- }
-
- ImmutableMultimap<TupleTag<?>, ThrowingConsumer<WindowedValue<?>>> tagToOutputMap =
- tagToOutputMapBuilder.build();
-
- // The main output consumers are looked up through the tag stored under the main output key.
- @SuppressWarnings({"unchecked", "rawtypes"})
- DoFnRunner<InputT, OutputT> runner = new FnApiDoFnRunner<>(
- pipelineOptions,
- doFnInfo.getDoFn(),
- (Collection<ThrowingConsumer<WindowedValue<OutputT>>>) (Collection)
- tagToOutputMap.get(doFnInfo.getOutputMap().get(doFnInfo.getMainOutput())),
- tagToOutputMap,
- doFnInfo.getWindowingStrategy());
-
- // Register the appropriate handlers. The runner is added as a consumer of every input
- // PCollection of the PTransform.
- addStartFunction.accept(runner::startBundle);
- for (String pcollectionId : pTransform.getInputsMap().values()) {
- pCollectionIdsToConsumers.put(
- pcollectionId,
- (ThrowingConsumer) (ThrowingConsumer<WindowedValue<InputT>>) runner::processElement);
- }
- addFinishFunction.accept(runner::finishBundle);
- return runner;
- }
- }
-
- //////////////////////////////////////////////////////////////////////////////////////////////////
-
- private final PipelineOptions pipelineOptions;
- private final DoFn<InputT, OutputT> doFn;
- private final Collection<ThrowingConsumer<WindowedValue<OutputT>>> mainOutputConsumers;
- private final Multimap<TupleTag<?>, ThrowingConsumer<WindowedValue<?>>> outputMap;
- private final DoFnInvoker<InputT, OutputT> doFnInvoker;
- private final StartBundleContext startBundleContext;
- private final ProcessBundleContext processBundleContext;
- private final FinishBundleContext finishBundleContext;
-
- /**
- * The lifetime of this member is only valid during {@link #processElement(WindowedValue)}.
- */
- private WindowedValue<InputT> currentElement;
-
- /**
- * The lifetime of this member is only valid during {@link #processElement(WindowedValue)}.
- */
- private BoundedWindow currentWindow;
-
- /**
- * Creates a runner for {@code doFn}.
- *
- * @param pipelineOptions options exposed to the {@link DoFn} through the bundle contexts
- * @param doFn the user function to invoke
- * @param mainOutputConsumers consumers of values emitted to the main output
- * @param outputMap consumers keyed by output tag, used for tagged outputs
- * @param windowingStrategy not read by this constructor
- */
- FnApiDoFnRunner(
- PipelineOptions pipelineOptions,
- DoFn<InputT, OutputT> doFn,
- Collection<ThrowingConsumer<WindowedValue<OutputT>>> mainOutputConsumers,
- Multimap<TupleTag<?>, ThrowingConsumer<WindowedValue<?>>> outputMap,
- WindowingStrategy windowingStrategy) {
- this.pipelineOptions = pipelineOptions;
- this.doFn = doFn;
- this.mainOutputConsumers = mainOutputConsumers;
- this.outputMap = outputMap;
- this.doFnInvoker = DoFnInvokers.invokerFor(doFn);
- // The context constructors call doFn.super(), which reads the doFn field of this runner, so
- // the contexts must be created only after that field has been assigned.
- this.startBundleContext = new StartBundleContext();
- this.processBundleContext = new ProcessBundleContext();
- this.finishBundleContext = new FinishBundleContext();
- }
-
- /** Invokes the {@link DoFn}'s {@code @StartBundle} method with the start-bundle context. */
- @Override
- public void startBundle() {
- doFnInvoker.invokeStartBundle(startBundleContext);
- }
-
-  /**
-   * Invokes the {@link DoFn}'s {@code @ProcessElement} method once for each window that
-   * {@code elem} belongs to.
-   *
-   * <p>The current element and current window are set for the duration of each invocation so
-   * that the process context can read them, and both are always cleared before this method
-   * returns, even if the invocation throws.
-   *
-   * @param elem the element to process
-   */
-  @Override
-  public void processElement(WindowedValue<InputT> elem) {
-    currentElement = elem;
-    try {
-      // Iterating the window collection directly avoids the unchecked cast that an explicit
-      // Iterator<BoundedWindow> would require.
-      for (BoundedWindow window : elem.getWindows()) {
-        currentWindow = window;
-        doFnInvoker.invokeProcessElement(processBundleContext);
-      }
-    } finally {
-      currentElement = null;
-      currentWindow = null;
-    }
-  }
-
- /**
- * Not yet supported by this runner.
- *
- * @throws UnsupportedOperationException always
- */
- @Override
- public void onTimer(
- String timerId,
- BoundedWindow window,
- Instant timestamp,
- TimeDomain timeDomain) {
- throw new UnsupportedOperationException("TODO: Add support for timers");
- }
-
- /** Invokes the {@link DoFn}'s {@code @FinishBundle} method with the finish-bundle context. */
- @Override
- public void finishBundle() {
- doFnInvoker.invokeFinishBundle(finishBundleContext);
- }
-
-  /**
-   * Delivers {@code output} to every consumer in {@code consumers}, in iteration order.
-   *
-   * <p>Anything thrown by a consumer is rethrown wrapped via {@link UserCodeException#wrap}, so
-   * consumers after the failing one are not invoked.
-   *
-   * @param consumers the consumers to deliver to
-   * @param output the value to deliver
-   */
-  private <T> void outputTo(
-      Collection<ThrowingConsumer<WindowedValue<T>>> consumers,
-      WindowedValue<T> output) {
-    try {
-      for (ThrowingConsumer<WindowedValue<T>> consumer : consumers) {
-        consumer.accept(output);
-      }
-    } catch (Throwable t) {
-      throw UserCodeException.wrap(t);
-    }
-  }
-
-  /**
-   * The {@link DoFnInvoker.ArgumentProvider} handed to the invoker while running
-   * {@link DoFn.StartBundle @StartBundle}.
-   *
-   * <p>Only the pipeline options and this context itself are available; every other accessor
-   * throws {@link UnsupportedOperationException}.
-   */
-  private class StartBundleContext
-      extends DoFn<InputT, OutputT>.StartBundleContext
-      implements DoFnInvoker.ArgumentProvider<InputT, OutputT> {
-
-    private StartBundleContext() {
-      doFn.super();
-    }
-
-    // ---- Arguments that are available during @StartBundle ----
-
-    @Override
-    public DoFn<InputT, OutputT>.StartBundleContext startBundleContext(
-        DoFn<InputT, OutputT> doFn) {
-      return this;
-    }
-
-    @Override
-    public PipelineOptions pipelineOptions() {
-      return pipelineOptions;
-    }
-
-    @Override
-    public PipelineOptions getPipelineOptions() {
-      return pipelineOptions;
-    }
-
-    // ---- Arguments that are not available during @StartBundle ----
-
-    @Override
-    public DoFn<InputT, OutputT>.ProcessContext processContext(DoFn<InputT, OutputT> doFn) {
-      throw new UnsupportedOperationException(
-          "Cannot access ProcessContext outside of @ProcessElement method.");
-    }
-
-    @Override
-    public DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext(
-        DoFn<InputT, OutputT> doFn) {
-      throw new UnsupportedOperationException(
-          "Cannot access FinishBundleContext outside of @FinishBundle method.");
-    }
-
-    @Override
-    public DoFn<InputT, OutputT>.OnTimerContext onTimerContext(DoFn<InputT, OutputT> doFn) {
-      throw new UnsupportedOperationException(
-          "Cannot access OnTimerContext outside of @OnTimer methods.");
-    }
-
-    @Override
-    public BoundedWindow window() {
-      throw new UnsupportedOperationException(
-          "Cannot access window outside of @ProcessElement and @OnTimer methods.");
-    }
-
-    @Override
-    public RestrictionTracker<?> restrictionTracker() {
-      throw new UnsupportedOperationException(
-          "Cannot access RestrictionTracker outside of @ProcessElement method.");
-    }
-
-    @Override
-    public State state(String stateId) {
-      throw new UnsupportedOperationException(
-          "Cannot access state outside of @ProcessElement and @OnTimer methods.");
-    }
-
-    @Override
-    public Timer timer(String timerId) {
-      throw new UnsupportedOperationException(
-          "Cannot access timers outside of @ProcessElement and @OnTimer methods.");
-    }
-  }
-
-  /**
-   * The {@link DoFnInvoker.ArgumentProvider} and {@link DoFn.ProcessContext} handed to the invoker
-   * while running {@link DoFn.ProcessElement @ProcessElement}.
-   *
-   * <p>Element, timestamp, pane and window are read from the enclosing runner's current element
-   * and current window fields. Timers, state, side inputs and splittable-DoFn features are not
-   * supported and throw {@link UnsupportedOperationException}.
-   */
-  private class ProcessBundleContext
-      extends DoFn<InputT, OutputT>.ProcessContext
-      implements DoFnInvoker.ArgumentProvider<InputT, OutputT> {
-
-    private ProcessBundleContext() {
-      doFn.super();
-    }
-
-    // ---- Access to the element currently being processed ----
-
-    @Override
-    public InputT element() {
-      return currentElement.getValue();
-    }
-
-    @Override
-    public Instant timestamp() {
-      return currentElement.getTimestamp();
-    }
-
-    @Override
-    public PaneInfo pane() {
-      return currentElement.getPane();
-    }
-
-    @Override
-    public BoundedWindow window() {
-      return currentWindow;
-    }
-
-    @Override
-    public PipelineOptions getPipelineOptions() {
-      return pipelineOptions;
-    }
-
-    @Override
-    public PipelineOptions pipelineOptions() {
-      return pipelineOptions;
-    }
-
-    @Override
-    public ProcessContext processContext(DoFn<InputT, OutputT> doFn) {
-      return this;
-    }
-
-    // ---- Emitting output ----
-
-    @Override
-    public void output(OutputT output) {
-      outputTo(mainOutputConsumers, inCurrentWindow(output, currentElement.getTimestamp()));
-    }
-
-    @Override
-    public void outputWithTimestamp(OutputT output, Instant timestamp) {
-      outputTo(mainOutputConsumers, inCurrentWindow(output, timestamp));
-    }
-
-    @Override
-    public <T> void output(TupleTag<T> tag, T output) {
-      outputTo(consumersFor(tag), inCurrentWindow(output, currentElement.getTimestamp()));
-    }
-
-    @Override
-    public <T> void outputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) {
-      outputTo(consumersFor(tag), inCurrentWindow(output, timestamp));
-    }
-
-    /** Wraps {@code value} in the current window, carrying the current element's pane. */
-    private <T> WindowedValue<T> inCurrentWindow(T value, Instant timestamp) {
-      return WindowedValue.of(value, timestamp, currentWindow, currentElement.getPane());
-    }
-
-    /**
-     * Looks up the consumers registered for {@code tag}.
-     *
-     * <p>NOTE(review): Guava documents {@code Multimap#get} as returning an empty collection,
-     * not null, for an absent key, so the null check below is not expected to fire and output
-     * to an unregistered tag is presumably dropped silently - confirm the intended behavior.
-     */
-    @SuppressWarnings({"unchecked", "rawtypes"})
-    private <T> Collection<ThrowingConsumer<WindowedValue<T>>> consumersFor(TupleTag<T> tag) {
-      Collection<ThrowingConsumer<WindowedValue<T>>> tagged = (Collection) outputMap.get(tag);
-      if (tagged == null) {
-        throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
-      }
-      return tagged;
-    }
-
-    // ---- Arguments that belong to other lifecycle methods ----
-
-    @Override
-    public DoFn.StartBundleContext startBundleContext(DoFn<InputT, OutputT> doFn) {
-      throw new UnsupportedOperationException(
-          "Cannot access StartBundleContext outside of @StartBundle method.");
-    }
-
-    @Override
-    public DoFn.FinishBundleContext finishBundleContext(DoFn<InputT, OutputT> doFn) {
-      throw new UnsupportedOperationException(
-          "Cannot access FinishBundleContext outside of @FinishBundle method.");
-    }
-
-    // ---- Features that are not implemented yet ----
-
-    @Override
-    public OnTimerContext onTimerContext(DoFn<InputT, OutputT> doFn) {
-      throw new UnsupportedOperationException("TODO: Add support for timers");
-    }
-
-    @Override
-    public Timer timer(String timerId) {
-      throw new UnsupportedOperationException("TODO: Add support for timers");
-    }
-
-    @Override
-    public State state(String stateId) {
-      throw new UnsupportedOperationException("TODO: Add support for state");
-    }
-
-    @Override
-    public <T> T sideInput(PCollectionView<T> view) {
-      throw new UnsupportedOperationException("TODO: Support side inputs");
-    }
-
-    @Override
-    public RestrictionTracker<?> restrictionTracker() {
-      throw new UnsupportedOperationException("TODO: Add support for SplittableDoFn");
-    }
-
-    @Override
-    public void updateWatermark(Instant watermark) {
-      throw new UnsupportedOperationException("TODO: Add support for SplittableDoFn");
-    }
-  }
-
-  /**
-   * The {@link DoFnInvoker.ArgumentProvider} handed to the invoker while running
-   * {@link DoFn.FinishBundle @FinishBundle}.
-   *
-   * <p>Besides the pipeline options and this context itself, it supports emitting output with an
-   * explicit timestamp and window; every other accessor throws
-   * {@link UnsupportedOperationException}.
-   */
-  private class FinishBundleContext
-      extends DoFn<InputT, OutputT>.FinishBundleContext
-      implements DoFnInvoker.ArgumentProvider<InputT, OutputT> {
-
-    private FinishBundleContext() {
-      doFn.super();
-    }
-
-    // ---- Arguments that are available during @FinishBundle ----
-
-    @Override
-    public DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext(
-        DoFn<InputT, OutputT> doFn) {
-      return this;
-    }
-
-    @Override
-    public PipelineOptions pipelineOptions() {
-      return pipelineOptions;
-    }
-
-    @Override
-    public PipelineOptions getPipelineOptions() {
-      return pipelineOptions;
-    }
-
-    // ---- Emitting output ----
-
-    @Override
-    public void output(OutputT output, Instant timestamp, BoundedWindow window) {
-      WindowedValue<OutputT> windowedOutput =
-          WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING);
-      outputTo(mainOutputConsumers, windowedOutput);
-    }
-
-    @Override
-    public <T> void output(TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) {
-      Collection<ThrowingConsumer<WindowedValue<T>>> taggedConsumers =
-          (Collection) outputMap.get(tag);
-      // NOTE(review): Guava documents Multimap#get as returning an empty collection, not null,
-      // for an absent key, so this check is not expected to fire - confirm the intended
-      // behavior for unregistered tags.
-      if (taggedConsumers == null) {
-        throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
-      }
-      WindowedValue<T> windowedOutput =
-          WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING);
-      outputTo(taggedConsumers, windowedOutput);
-    }
-
-    // ---- Arguments that are not available during @FinishBundle ----
-
-    @Override
-    public DoFn<InputT, OutputT>.StartBundleContext startBundleContext(
-        DoFn<InputT, OutputT> doFn) {
-      throw new UnsupportedOperationException(
-          "Cannot access StartBundleContext outside of @StartBundle method.");
-    }
-
-    @Override
-    public DoFn<InputT, OutputT>.ProcessContext processContext(DoFn<InputT, OutputT> doFn) {
-      throw new UnsupportedOperationException(
-          "Cannot access ProcessContext outside of @ProcessElement method.");
-    }
-
-    @Override
-    public DoFn<InputT, OutputT>.OnTimerContext onTimerContext(DoFn<InputT, OutputT> doFn) {
-      throw new UnsupportedOperationException(
-          "Cannot access OnTimerContext outside of @OnTimer methods.");
-    }
-
-    @Override
-    public BoundedWindow window() {
-      throw new UnsupportedOperationException(
-          "Cannot access window outside of @ProcessElement and @OnTimer methods.");
-    }
-
-    @Override
-    public RestrictionTracker<?> restrictionTracker() {
-      throw new UnsupportedOperationException(
-          "Cannot access RestrictionTracker outside of @ProcessElement method.");
-    }
-
-    @Override
-    public State state(String stateId) {
-      throw new UnsupportedOperationException(
-          "Cannot access state outside of @ProcessElement and @OnTimer methods.");
-    }
-
-    @Override
-    public Timer timer(String timerId) {
-      throw new UnsupportedOperationException(
-          "Cannot access timers outside of @ProcessElement and @OnTimer methods.");
-    }
-  }
-}
http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java
deleted file mode 100644
index b325db4..0000000
--- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.core;
-
-import com.google.common.collect.Multimap;
-import java.io.IOException;
-import java.util.Map;
-import java.util.function.Consumer;
-import java.util.function.Supplier;
-import org.apache.beam.fn.harness.data.BeamFnDataClient;
-import org.apache.beam.fn.harness.fn.ThrowingConsumer;
-import org.apache.beam.fn.harness.fn.ThrowingRunnable;
-import org.apache.beam.sdk.common.runner.v1.RunnerApi;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.util.WindowedValue;
-
-/**
- * A factory able to instantiate an appropriate handler for a given PTransform.
- *
- * @param <T> the type of handler this factory creates
- */
-public interface PTransformRunnerFactory<T> {
-
- /**
- * Creates and returns a handler for a given PTransform. Note that the handler must support
- * processing multiple bundles. The handler will be discarded if an error is thrown during
- * element processing, or during execution of start/finish.
- *
- * @param pipelineOptions Pipeline options
- * @param beamFnDataClient The {@link BeamFnDataClient} made available to the handler.
- * @param pTransformId The id of the PTransform.
- * @param pTransform The PTransform definition.
- * @param processBundleInstructionId A supplier containing the active process bundle instruction
- * id.
- * @param pCollections A mapping from PCollection id to PCollection definition.
- * @param coders A mapping from coder id to coder definition.
- * @param pCollectionIdsToConsumers A mapping from PCollection id to a collection of consumers.
- * Note that if this handler is a consumer, it should register itself within this multimap under
- * the appropriate PCollection ids. Also note that all output consumers needed by this PTransform
- * (based on the values of the {@link RunnerApi.PTransform#getOutputsMap()}) will have already
- * registered within this multimap.
- * @param addStartFunction A consumer to register a start bundle handler with.
- * @param addFinishFunction A consumer to register a finish bundle handler with.
- * @return The handler created for the PTransform.
- * @throws IOException if the implementation fails to create the handler.
- */
- T createRunnerForPTransform(
- PipelineOptions pipelineOptions,
- BeamFnDataClient beamFnDataClient,
- String pTransformId,
- RunnerApi.PTransform pTransform,
- Supplier<String> processBundleInstructionId,
- Map<String, RunnerApi.PCollection> pCollections,
- Map<String, RunnerApi.Coder> coders,
- Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers,
- Consumer<ThrowingRunnable> addStartFunction,
- Consumer<ThrowingRunnable> addFinishFunction) throws IOException;
-
- /**
- * A registrar which can return a mapping from {@link RunnerApi.FunctionSpec#getUrn()} to
- * a factory capable of instantiating an appropriate handler.
- */
- interface Registrar {
- /**
- * Returns a mapping from {@link RunnerApi.FunctionSpec#getUrn()} to a factory capable of
- * instantiating an appropriate handler.
- *
- * @return The factories keyed by the URN they handle.
- */
- Map<String, PTransformRunnerFactory> getPTransformRunnerFactories();
- }
-}
http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/main/java/org/apache/beam/runners/core/package-info.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/package-info.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/package-info.java
deleted file mode 100644
index d250a6a..0000000
--- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/package-info.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * Provides utilities for Beam runner authors.
- */
-package org.apache.beam.runners.core;
http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
new file mode 100644
index 0000000..a7c6666
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
@@ -0,0 +1,281 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness;
+
+import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.fail;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.verifyZeroInteractions;
+import static org.mockito.Mockito.when;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.base.Suppliers;
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Multimap;
+import com.google.common.util.concurrent.Uninterruptibles;
+import com.google.protobuf.Any;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.BytesValue;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.ServiceLoader;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar;
+import org.apache.beam.fn.harness.data.BeamFnDataClient;
+import org.apache.beam.fn.harness.fn.ThrowingConsumer;
+import org.apache.beam.fn.harness.fn.ThrowingRunnable;
+import org.apache.beam.fn.harness.test.TestExecutors;
+import org.apache.beam.fn.harness.test.TestExecutors.TestExecutorService;
+import org.apache.beam.fn.v1.BeamFnApi;
+import org.apache.beam.runners.dataflow.util.CloudObjects;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.hamcrest.collection.IsMapContaining;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+/** Tests for {@link BeamFnDataReadRunner}. */
+@RunWith(JUnit4.class)
+public class BeamFnDataReadRunnerTest {
+
+ private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+ private static final BeamFnApi.RemoteGrpcPort PORT_SPEC = BeamFnApi.RemoteGrpcPort.newBuilder()
+ .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance()).build();
+ private static final RunnerApi.FunctionSpec FUNCTION_SPEC = RunnerApi.FunctionSpec.newBuilder()
+ .setParameter(Any.pack(PORT_SPEC)).build();
+ private static final Coder<WindowedValue<String>> CODER =
+ WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE);
+ private static final String CODER_SPEC_ID = "string-coder-id";
+ private static final RunnerApi.Coder CODER_SPEC;
+ private static final String URN = "urn:org.apache.beam:source:runner:0.1";
+
+ // Builds CODER_SPEC by converting CODER to its CloudObject form, serializing that with
+ // OBJECT_MAPPER, and packing the bytes as a BytesValue into the coder's function spec
+ // parameter. A serialization failure aborts class initialization.
+ static {
+ try {
+ CODER_SPEC = RunnerApi.Coder.newBuilder().setSpec(
+ RunnerApi.SdkFunctionSpec.newBuilder().setSpec(
+ RunnerApi.FunctionSpec.newBuilder().setParameter(
+ Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom(
+ OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(CODER))))
+ .build()))
+ .build())
+ .build())
+ .build();
+ } catch (IOException e) {
+ throw new ExceptionInInitializerError(e);
+ }
+ }
+ private static final BeamFnApi.Target INPUT_TARGET = BeamFnApi.Target.newBuilder()
+ .setPrimitiveTransformReference("1")
+ .setName("out")
+ .build();
+
+ @Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool);
+ @Mock private BeamFnDataClient mockBeamFnDataClient;
+ @Captor private ArgumentCaptor<ThrowingConsumer<WindowedValue<String>>> consumerCaptor;
+
+ /** Initializes the {@code @Mock} and {@code @Captor} fields of this test before each test. */
+ @Before
+ public void setUp() {
+ MockitoAnnotations.initMocks(this);
+ }
+
+  /**
+   * Verifies that the factory does not touch the data client until the registered start function
+   * runs, that the start function registers an inbound consumer for the expected target and
+   * coder, and that values received by that consumer reach the output PCollection's consumers.
+   */
+  @Test
+  public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception {
+    String bundleId = "57";
+    String outputId = "101";
+
+    List<WindowedValue<String>> outputValues = new ArrayList<>();
+
+    Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create();
+    consumers.put("outputPC",
+        (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) outputValues::add);
+    List<ThrowingRunnable> startFunctions = new ArrayList<>();
+    List<ThrowingRunnable> finishFunctions = new ArrayList<>();
+
+    // Use the shared URN constant rather than repeating the literal, so this test and
+    // testRegistration cannot drift apart.
+    RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder()
+        .setUrn(URN)
+        .setParameter(Any.pack(PORT_SPEC))
+        .build();
+
+    RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
+        .setSpec(functionSpec)
+        .putOutputs(outputId, "outputPC")
+        .build();
+
+    new BeamFnDataReadRunner.Factory<String>().createRunnerForPTransform(
+        PipelineOptionsFactory.create(),
+        mockBeamFnDataClient,
+        "pTransformId",
+        pTransform,
+        Suppliers.ofInstance(bundleId)::get,
+        ImmutableMap.of("outputPC",
+            RunnerApi.PCollection.newBuilder().setCoderId(CODER_SPEC_ID).build()),
+        ImmutableMap.of(CODER_SPEC_ID, CODER_SPEC),
+        consumers,
+        startFunctions::add,
+        finishFunctions::add);
+
+    // Creating the runner alone must not contact the data client.
+    verifyZeroInteractions(mockBeamFnDataClient);
+
+    CompletableFuture<Void> completionFuture = new CompletableFuture<>();
+    when(mockBeamFnDataClient.forInboundConsumer(any(), any(), any(), any()))
+        .thenReturn(completionFuture);
+    Iterables.getOnlyElement(startFunctions).run();
+    verify(mockBeamFnDataClient).forInboundConsumer(
+        eq(PORT_SPEC.getApiServiceDescriptor()),
+        eq(KV.of(bundleId, BeamFnApi.Target.newBuilder()
+            .setPrimitiveTransformReference("pTransformId")
+            .setName(outputId)
+            .build())),
+        eq(CODER),
+        consumerCaptor.capture());
+
+    consumerCaptor.getValue().accept(valueInGlobalWindow("TestValue"));
+    assertThat(outputValues, contains(valueInGlobalWindow("TestValue")));
+    outputValues.clear();
+
+    assertThat(consumers.keySet(), containsInAnyOrder("outputPC"));
+
+    // Complete the read before running the finish function so that it does not block.
+    completionFuture.complete(null);
+    Iterables.getOnlyElement(finishFunctions).run();
+
+    verifyNoMoreInteractions(mockBeamFnDataClient);
+  }
+
+  /**
+   * Verifies that a single {@link BeamFnDataReadRunner} can be used for two consecutive bundles:
+   * each bundle registers its own inbound consumer and forwards the received values to every
+   * downstream consumer.
+   */
+  @Test
+  public void testReuseForMultipleBundles() throws Exception {
+    CompletableFuture<Void> bundle1Future = new CompletableFuture<>();
+    CompletableFuture<Void> bundle2Future = new CompletableFuture<>();
+    when(mockBeamFnDataClient.forInboundConsumer(any(), any(), any(), any()))
+        .thenReturn(bundle1Future)
+        .thenReturn(bundle2Future);
+    List<WindowedValue<String>> valuesA = new ArrayList<>();
+    List<WindowedValue<String>> valuesB = new ArrayList<>();
+
+    AtomicReference<String> bundleId = new AtomicReference<>("0");
+    BeamFnDataReadRunner<String> readRunner = new BeamFnDataReadRunner<>(
+        FUNCTION_SPEC,
+        bundleId::get,
+        INPUT_TARGET,
+        CODER_SPEC,
+        mockBeamFnDataClient,
+        ImmutableList.of(valuesA::add, valuesB::add));
+
+    // Process for bundle id 0
+    readRunner.registerInputLocation();
+
+    verify(mockBeamFnDataClient).forInboundConsumer(
+        eq(PORT_SPEC.getApiServiceDescriptor()),
+        eq(KV.of(bundleId.get(), INPUT_TARGET)),
+        eq(CODER),
+        consumerCaptor.capture());
+
+    deliverThenComplete(bundle1Future, "ABC", "DEF");
+
+    readRunner.blockTillReadFinishes();
+    assertThat(valuesA, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF")));
+    assertThat(valuesB, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF")));
+
+    // Process for bundle id 1
+    bundleId.set("1");
+    valuesA.clear();
+    valuesB.clear();
+    readRunner.registerInputLocation();
+
+    verify(mockBeamFnDataClient).forInboundConsumer(
+        eq(PORT_SPEC.getApiServiceDescriptor()),
+        eq(KV.of(bundleId.get(), INPUT_TARGET)),
+        eq(CODER),
+        consumerCaptor.capture());
+
+    deliverThenComplete(bundle2Future, "GHI", "JKL");
+
+    readRunner.blockTillReadFinishes();
+    assertThat(valuesA, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL")));
+    assertThat(valuesB, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL")));
+
+    verifyNoMoreInteractions(mockBeamFnDataClient);
+  }
+
+  /**
+   * Submits a task that, after a short delay, feeds {@code first} and {@code second} to the most
+   * recently captured inbound consumer and then completes {@code bundleFuture}. A failure while
+   * feeding completes the future exceptionally instead.
+   */
+  private void deliverThenComplete(
+      CompletableFuture<Void> bundleFuture, String first, String second) {
+    executor.submit(() -> {
+      // Sleep for some small amount of time simulating the parent blocking
+      Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);
+      try {
+        consumerCaptor.getValue().accept(valueInGlobalWindow(first));
+        consumerCaptor.getValue().accept(valueInGlobalWindow(second));
+      } catch (Exception e) {
+        bundleFuture.completeExceptionally(e);
+      } finally {
+        // No effect if the future was already completed exceptionally above.
+        bundleFuture.complete(null);
+      }
+    });
+  }
+
+ @Test
+ public void testRegistration() {
+ for (Registrar registrar :
+ ServiceLoader.load(Registrar.class)) {
+ if (registrar instanceof BeamFnDataReadRunner.Registrar) {
+ assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN));
+ return;
+ }
+ }
+ fail("Expected registrar not found.");
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
new file mode 100644
index 0000000..28838b1
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness;
+
+import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.verifyZeroInteractions;
+import static org.mockito.Mockito.when;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.base.Suppliers;
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Multimap;
+import com.google.protobuf.Any;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.BytesValue;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.ServiceLoader;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar;
+import org.apache.beam.fn.harness.data.BeamFnDataClient;
+import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer;
+import org.apache.beam.fn.harness.fn.ThrowingConsumer;
+import org.apache.beam.fn.harness.fn.ThrowingRunnable;
+import org.apache.beam.fn.v1.BeamFnApi;
+import org.apache.beam.runners.dataflow.util.CloudObjects;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.hamcrest.collection.IsMapContaining;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Matchers;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+/** Tests for {@link BeamFnDataWriteRunner}. */
+@RunWith(JUnit4.class)
+public class BeamFnDataWriteRunnerTest {
+
+ private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); // Writes the cloud-object form of CODER to bytes when CODER_SPEC is built below.
+ private static final BeamFnApi.RemoteGrpcPort PORT_SPEC = BeamFnApi.RemoteGrpcPort.newBuilder()
+ .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance()).build(); // Port spec carrying the default service descriptor; the tests verify against it.
+ private static final RunnerApi.FunctionSpec FUNCTION_SPEC = RunnerApi.FunctionSpec.newBuilder()
+ .setParameter(Any.pack(PORT_SPEC)).build(); // Function spec whose only payload is PORT_SPEC; no URN is set on it.
+ private static final String CODER_ID = "string-coder-id"; // Key under which CODER_SPEC is handed to the factory.
+ private static final Coder<WindowedValue<String>> CODER =
+ WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); // Coder the tests expect the runner to pass to the data client.
+ private static final RunnerApi.Coder CODER_SPEC; // Proto wrapper around the serialized CODER; assigned in the static initializer.
+ private static final String URN = "urn:org.apache.beam:sink:runner:0.1"; // Key that testRegistration expects in the registrar's factory map.
+
+ static { // Builds CODER_SPEC here because writeValueAsBytes declares IOException.
+ try {
+ CODER_SPEC = RunnerApi.Coder.newBuilder().setSpec(
+ RunnerApi.SdkFunctionSpec.newBuilder().setSpec(
+ RunnerApi.FunctionSpec.newBuilder().setParameter(
+ Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom(
+ OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(CODER))))
+ .build()))
+ .build())
+ .build())
+ .build();
+ } catch (IOException e) {
+ throw new ExceptionInInitializerError(e); // Surface a serialization failure as a class-initialization error.
+ }
+ }
+ private static final BeamFnApi.Target OUTPUT_TARGET = BeamFnApi.Target.newBuilder()
+ .setPrimitiveTransformReference("1")
+ .setName("out")
+ .build(); // Target used when the runner is constructed directly in testReuseForMultipleBundles.
+
+ @Mock private BeamFnDataClient mockBeamFnDataClient; // Populated by MockitoAnnotations.initMocks in setUp.
+
+ @Before
+ public void setUp() {
+ MockitoAnnotations.initMocks(this); // Initializes the @Mock fields of this test instance before each test.
+ }
+
+
+ @Test
+ public void testCreatingAndProcessingBeamFnDataWriteRunner() throws Exception {
+ String bundleId = "57L";
+ String inputId = "100L";
+
+ Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create();
+ List<ThrowingRunnable> startFunctions = new ArrayList<>();
+ List<ThrowingRunnable> finishFunctions = new ArrayList<>();
+
+ RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder()
+ .setUrn("urn:org.apache.beam:sink:runner:0.1")
+ .setParameter(Any.pack(PORT_SPEC))
+ .build();
+
+ RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
+ .setSpec(functionSpec)
+ .putInputs(inputId, "inputPC")
+ .build();
+
+ new BeamFnDataWriteRunner.Factory<String>().createRunnerForPTransform(
+ PipelineOptionsFactory.create(),
+ mockBeamFnDataClient,
+ "ptransformId",
+ pTransform,
+ Suppliers.ofInstance(bundleId)::get,
+ ImmutableMap.of("inputPC",
+ RunnerApi.PCollection.newBuilder().setCoderId(CODER_ID).build()),
+ ImmutableMap.of(CODER_ID, CODER_SPEC),
+ consumers,
+ startFunctions::add,
+ finishFunctions::add);
+
+ verifyZeroInteractions(mockBeamFnDataClient);
+
+ List<WindowedValue<String>> outputValues = new ArrayList<>();
+ AtomicBoolean wasCloseCalled = new AtomicBoolean();
+ CloseableThrowingConsumer<WindowedValue<String>> outputConsumer =
+ new CloseableThrowingConsumer<WindowedValue<String>>(){
+ @Override
+ public void close() throws Exception {
+ wasCloseCalled.set(true);
+ }
+
+ @Override
+ public void accept(WindowedValue<String> t) throws Exception {
+ outputValues.add(t);
+ }
+ };
+
+ when(mockBeamFnDataClient.forOutboundConsumer(
+ any(),
+ any(),
+ Matchers.<Coder<WindowedValue<String>>>any())).thenReturn(outputConsumer);
+ Iterables.getOnlyElement(startFunctions).run();
+ verify(mockBeamFnDataClient).forOutboundConsumer(
+ eq(PORT_SPEC.getApiServiceDescriptor()),
+ eq(KV.of(bundleId, BeamFnApi.Target.newBuilder()
+ .setPrimitiveTransformReference("ptransformId")
+ .setName(inputId)
+ .build())),
+ eq(CODER));
+
+ assertThat(consumers.keySet(), containsInAnyOrder("inputPC"));
+ Iterables.getOnlyElement(consumers.get("inputPC")).accept(valueInGlobalWindow("TestValue"));
+ assertThat(outputValues, contains(valueInGlobalWindow("TestValue")));
+ outputValues.clear();
+
+ assertFalse(wasCloseCalled.get());
+ Iterables.getOnlyElement(finishFunctions).run();
+ assertTrue(wasCloseCalled.get());
+
+ verifyNoMoreInteractions(mockBeamFnDataClient);
+ }
+
+  /**
+   * Drives one directly constructed {@link BeamFnDataWriteRunner} through two bundles and checks
+   * that each registration requests an outbound consumer keyed by the bundle id current at that
+   * time, and that each bundle's consumer ends up closed and holding that bundle's elements.
+   */
+  @Test
+  public void testReuseForMultipleBundles() throws Exception {
+    RecordingConsumer<WindowedValue<String>> firstBundleSink = new RecordingConsumer<>();
+    RecordingConsumer<WindowedValue<String>> secondBundleSink = new RecordingConsumer<>();
+    // Consecutive requests for an outbound consumer yield the first sink, then the second.
+    when(mockBeamFnDataClient.forOutboundConsumer(
+        any(),
+        any(),
+        Matchers.<Coder<WindowedValue<String>>>any()))
+        .thenReturn(firstBundleSink)
+        .thenReturn(secondBundleSink);
+
+    AtomicReference<String> currentBundleId = new AtomicReference<>("0");
+    BeamFnDataWriteRunner<String> runner = new BeamFnDataWriteRunner<>(
+        FUNCTION_SPEC,
+        currentBundleId::get,
+        OUTPUT_TARGET,
+        CODER_SPEC,
+        mockBeamFnDataClient);
+
+    // Process for bundle id 0
+    runner.registerForOutput();
+    verify(mockBeamFnDataClient).forOutboundConsumer(
+        eq(PORT_SPEC.getApiServiceDescriptor()),
+        eq(KV.of(currentBundleId.get(), OUTPUT_TARGET)),
+        eq(CODER));
+
+    runner.consume(valueInGlobalWindow("ABC"));
+    runner.consume(valueInGlobalWindow("DEF"));
+    runner.close();
+
+    assertTrue(firstBundleSink.closed);
+    assertThat(firstBundleSink, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF")));
+
+    // Process for bundle id 1
+    currentBundleId.set("1");
+    firstBundleSink.clear();
+    secondBundleSink.clear();
+    runner.registerForOutput();
+    verify(mockBeamFnDataClient).forOutboundConsumer(
+        eq(PORT_SPEC.getApiServiceDescriptor()),
+        eq(KV.of(currentBundleId.get(), OUTPUT_TARGET)),
+        eq(CODER));
+
+    runner.consume(valueInGlobalWindow("GHI"));
+    runner.consume(valueInGlobalWindow("JKL"));
+    runner.close();
+
+    assertTrue(secondBundleSink.closed);
+    assertThat(secondBundleSink, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL")));
+    verifyNoMoreInteractions(mockBeamFnDataClient);
+  }
+
+  /**
+   * A list-backed {@link CloseableThrowingConsumer} that stores every accepted element and
+   * rejects any element offered after {@link #close()} has been called.
+   */
+  private static class RecordingConsumer<T> extends ArrayList<T>
+      implements CloseableThrowingConsumer<T> {
+    private boolean closed;
+
+    @Override
+    public void accept(T t) throws Exception {
+      if (!closed) {
+        add(t);
+        return;
+      }
+      throw new IllegalStateException("Consumer is closed but attempting to consume " + t);
+    }
+
+    @Override
+    public void close() throws Exception {
+      closed = true;
+    }
+  }
+
+ @Test
+ public void testRegistration() {
+ for (Registrar registrar :
+ ServiceLoader.load(Registrar.class)) {
+ if (registrar instanceof BeamFnDataWriteRunner.Registrar) {
+ assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN));
+ return;
+ }
+ }
+ fail("Expected registrar not found.");
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java
new file mode 100644
index 0000000..7aec161
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness;
+
+import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.collection.IsEmptyCollection.empty;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.fail;
+
+import com.google.common.base.Suppliers;
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Multimap;
+import com.google.protobuf.Any;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.BytesValue;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.ServiceLoader;
+import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar;
+import org.apache.beam.fn.harness.fn.ThrowingConsumer;
+import org.apache.beam.fn.harness.fn.ThrowingRunnable;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.io.CountingSource;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.util.SerializableUtils;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.hamcrest.Matchers;
+import org.hamcrest.collection.IsMapContaining;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link BoundedSourceRunner}. */
+@RunWith(JUnit4.class)
+public class BoundedSourceRunnerTest {
+
+ public static final String URN = "urn:org.apache.beam:source:java:0.1"; // Key that testRegistration expects in the registrar's factory map.
+
+  /**
+   * Runs the read loop over two sources in sequence and checks that every registered consumer
+   * receives the elements of both sources, in order.
+   */
+  @Test
+  public void testRunReadLoopWithMultipleSources() throws Exception {
+    List<WindowedValue<Long>> firstSink = new ArrayList<>();
+    List<WindowedValue<Long>> secondSink = new ArrayList<>();
+    Collection<ThrowingConsumer<WindowedValue<Long>>> consumers =
+        ImmutableList.of(firstSink::add, secondSink::add);
+    BoundedSourceRunner<BoundedSource<Long>, Long> runner = new BoundedSourceRunner<>(
+        PipelineOptionsFactory.create(),
+        RunnerApi.FunctionSpec.getDefaultInstance(),
+        consumers);
+
+    // Read a two-element source first, then a one-element source.
+    for (long upperBound : new long[] {2, 1}) {
+      runner.runReadLoop(valueInGlobalWindow(CountingSource.upTo(upperBound)));
+    }
+
+    for (List<WindowedValue<Long>> sink : ImmutableList.of(firstSink, secondSink)) {
+      assertThat(sink,
+          contains(valueInGlobalWindow(0L), valueInGlobalWindow(1L), valueInGlobalWindow(0L)));
+    }
+  }
+
+ @Test
+ public void testRunReadLoopWithEmptySource() throws Exception {
+ List<WindowedValue<Long>> outValues = new ArrayList<>();
+ Collection<ThrowingConsumer<WindowedValue<Long>>> consumers =
+ ImmutableList.of(outValues::add);
+
+ BoundedSourceRunner<BoundedSource<Long>, Long> runner = new BoundedSourceRunner<>(
+ PipelineOptionsFactory.create(),
+ RunnerApi.FunctionSpec.getDefaultInstance(),
+ consumers);
+
+ runner.runReadLoop(valueInGlobalWindow(CountingSource.upTo(0)));
+
+ assertThat(outValues, empty());
+ }
+
+ @Test
+ public void testStart() throws Exception {
+ List<WindowedValue<Long>> outValues = new ArrayList<>();
+ Collection<ThrowingConsumer<WindowedValue<Long>>> consumers =
+ ImmutableList.of(outValues::add);
+
+ ByteString encodedSource =
+ ByteString.copyFrom(SerializableUtils.serializeToByteArray(CountingSource.upTo(3)));
+
+ BoundedSourceRunner<BoundedSource<Long>, Long> runner = new BoundedSourceRunner<>(
+ PipelineOptionsFactory.create(),
+ RunnerApi.FunctionSpec.newBuilder().setParameter(
+ Any.pack(BytesValue.newBuilder().setValue(encodedSource).build())).build(),
+ consumers);
+
+ runner.start();
+
+ assertThat(outValues,
+ contains(valueInGlobalWindow(0L), valueInGlobalWindow(1L), valueInGlobalWindow(2L)));
+ }
+
+  /**
+   * Creates a {@link BoundedSourceRunner} through its factory and exercises both ways of feeding
+   * it a source: the serialized source embedded in the function spec, which is read when the
+   * start function runs, and a source delivered as an element on the input PCollection.
+   * No finish function is expected to be registered.
+   */
+  @Test
+  public void testCreatingAndProcessingSourceFromFactory() throws Exception {
+    // The counting sources used below produce Long elements, so the sink is typed as Long.
+    List<WindowedValue<Long>> outputValues = new ArrayList<>();
+
+    Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create();
+    // The raw cast is needed because the registry is keyed on WindowedValue<?> consumers.
+    consumers.put("outputPC",
+        (ThrowingConsumer) (ThrowingConsumer<WindowedValue<Long>>) outputValues::add);
+    List<ThrowingRunnable> startFunctions = new ArrayList<>();
+    List<ThrowingRunnable> finishFunctions = new ArrayList<>();
+
+    RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder()
+        .setUrn(URN)
+        .setParameter(Any.pack(BytesValue.newBuilder()
+            .setValue(ByteString.copyFrom(
+                SerializableUtils.serializeToByteArray(CountingSource.upTo(3))))
+            .build()))
+        .build();
+
+    RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
+        .setSpec(functionSpec)
+        .putInputs("input", "inputPC")
+        .putOutputs("output", "outputPC")
+        .build();
+
+    new BoundedSourceRunner.Factory<>().createRunnerForPTransform(
+        PipelineOptionsFactory.create(),
+        null /* beamFnDataClient */,
+        "pTransformId",
+        pTransform,
+        Suppliers.ofInstance("57L")::get,
+        ImmutableMap.of(),
+        ImmutableMap.of(),
+        consumers,
+        startFunctions::add,
+        finishFunctions::add);
+
+    // This is testing a deprecated way of running sources and should be removed
+    // once all source definitions are instead propagated along the input edge.
+    Iterables.getOnlyElement(startFunctions).run();
+    assertThat(outputValues, contains(
+        valueInGlobalWindow(0L),
+        valueInGlobalWindow(1L),
+        valueInGlobalWindow(2L)));
+    outputValues.clear();
+
+    // Check that when passing a source along as an input, the source is processed.
+    assertThat(consumers.keySet(), containsInAnyOrder("inputPC", "outputPC"));
+    Iterables.getOnlyElement(consumers.get("inputPC")).accept(
+        valueInGlobalWindow(CountingSource.upTo(2)));
+    assertThat(outputValues, contains(
+        valueInGlobalWindow(0L),
+        valueInGlobalWindow(1L)));
+
+    assertThat(finishFunctions, Matchers.empty());
+  }
+
+ @Test
+ public void testRegistration() {
+ for (Registrar registrar :
+ ServiceLoader.load(Registrar.class)) {
+ if (registrar instanceof BoundedSourceRunner.Registrar) {
+ assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN));
+ return;
+ }
+ }
+ fail("Expected registrar not found.");
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
new file mode 100644
index 0000000..98362a2
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness;
+
+import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow;
+import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.fail;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.base.Suppliers;
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Multimap;
+import com.google.protobuf.Any;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.BytesValue;
+import com.google.protobuf.Message;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.ServiceLoader;
+import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar;
+import org.apache.beam.fn.harness.fn.ThrowingConsumer;
+import org.apache.beam.fn.harness.fn.ThrowingRunnable;
+import org.apache.beam.runners.core.construction.ParDoTranslation;
+import org.apache.beam.runners.dataflow.util.CloudObjects;
+import org.apache.beam.runners.dataflow.util.DoFnInfo;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.util.SerializableUtils;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.hamcrest.collection.IsMapContaining;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link FnApiDoFnRunner}. */
+@RunWith(JUnit4.class)
+public class FnApiDoFnRunnerTest {
+
+ private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); // Writes the cloud-object form of STRING_CODER to bytes when STRING_CODER_SPEC is built below.
+ private static final Coder<WindowedValue<String>> STRING_CODER =
+ WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); // Windowed-value coder that gets wrapped into STRING_CODER_SPEC.
+ private static final String STRING_CODER_SPEC_ID = "999L"; // Registry key paired with STRING_CODER_SPEC in testCreatingAndProcessingDoFn.
+ private static final RunnerApi.Coder STRING_CODER_SPEC; // Proto wrapper around the serialized STRING_CODER; assigned in the static initializer.
+
+ static { // Builds STRING_CODER_SPEC here because writeValueAsBytes declares IOException.
+ try {
+ STRING_CODER_SPEC = RunnerApi.Coder.newBuilder()
+ .setSpec(RunnerApi.SdkFunctionSpec.newBuilder()
+ .setSpec(RunnerApi.FunctionSpec.newBuilder()
+ .setParameter(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom(
+ OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(STRING_CODER))))
+ .build())))
+ .build())
+ .build();
+ } catch (IOException e) {
+ throw new ExceptionInInitializerError(e); // Surface a serialization failure as a class-initialization error.
+ }
+ }
+
+ private static class TestDoFn extends DoFn<String, String> {
+ private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput");
+ private static final TupleTag<String> additionalOutput = new TupleTag<>("output");
+
+ private BoundedWindow window;
+
+ @ProcessElement
+ public void processElement(ProcessContext context, BoundedWindow window) {
+ context.output("MainOutput" + context.element());
+ context.output(additionalOutput, "AdditionalOutput" + context.element());
+ this.window = window;
+ }
+
+ @FinishBundle
+ public void finishBundle(FinishBundleContext context) {
+ if (window != null) {
+ context.output("FinishBundle", window.maxTimestamp(), window);
+ window = null;
+ }
+ }
+ }
+
+  /**
+   * Create a DoFn that has 2 inputs (inputATarget, inputBTarget) and 2 outputs
+   * (mainOutput, output). Validate that inputs are fed to the {@link DoFn} and that outputs
+   * are directed to the correct consumers.
+   */
+  @Test
+  public void testCreatingAndProcessingDoFn() throws Exception {
+    // NOTE(review): this registry is not passed to the factory call below; confirm whether it
+    // is still needed.
+    Map<String, Message> fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC);
+    String pTransformId = "pTransformId";
+    String mainOutputId = "101";
+    String additionalOutputId = "102";
+
+    DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn(
+        new TestDoFn(),
+        WindowingStrategy.globalDefault(),
+        ImmutableList.of(),
+        StringUtf8Coder.of(),
+        Long.parseLong(mainOutputId),
+        ImmutableMap.of(
+            Long.parseLong(mainOutputId), TestDoFn.mainOutput,
+            Long.parseLong(additionalOutputId), TestDoFn.additionalOutput));
+    RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder()
+        .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN)
+        .setParameter(Any.pack(BytesValue.newBuilder()
+            .setValue(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo)))
+            .build()))
+        .build();
+    RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
+        .setSpec(functionSpec)
+        .putInputs("inputA", "inputATarget")
+        .putInputs("inputB", "inputBTarget")
+        .putOutputs(mainOutputId, "mainOutputTarget")
+        .putOutputs(additionalOutputId, "additionalOutputTarget")
+        .build();
+
+    // Sinks registered for the two output PCollections before the runner is created.
+    List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+    List<WindowedValue<String>> additionalOutputValues = new ArrayList<>();
+    Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create();
+    consumers.put("mainOutputTarget",
+        (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) mainOutputValues::add);
+    consumers.put("additionalOutputTarget",
+        (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) additionalOutputValues::add);
+    List<ThrowingRunnable> startFunctions = new ArrayList<>();
+    List<ThrowingRunnable> finishFunctions = new ArrayList<>();
+
+    new FnApiDoFnRunner.Factory<>().createRunnerForPTransform(
+        PipelineOptionsFactory.create(),
+        null /* beamFnDataClient */,
+        pTransformId,
+        pTransform,
+        Suppliers.ofInstance("57L")::get,
+        ImmutableMap.of(),
+        ImmutableMap.of(),
+        consumers,
+        startFunctions::add,
+        finishFunctions::add);
+
+    Iterables.getOnlyElement(startFunctions).run();
+    mainOutputValues.clear();
+
+    assertThat(consumers.keySet(), containsInAnyOrder(
+        "inputATarget", "inputBTarget", "mainOutputTarget", "additionalOutputTarget"));
+
+    // Feed two elements through input A and one through input B so that both registered
+    // input consumers are exercised.
+    Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A1"));
+    Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A2"));
+    Iterables.getOnlyElement(consumers.get("inputBTarget")).accept(valueInGlobalWindow("B"));
+    assertThat(mainOutputValues, contains(
+        valueInGlobalWindow("MainOutputA1"),
+        valueInGlobalWindow("MainOutputA2"),
+        valueInGlobalWindow("MainOutputB")));
+    assertThat(additionalOutputValues, contains(
+        valueInGlobalWindow("AdditionalOutputA1"),
+        valueInGlobalWindow("AdditionalOutputA2"),
+        valueInGlobalWindow("AdditionalOutputB")));
+    mainOutputValues.clear();
+    additionalOutputValues.clear();
+
+    // Finishing the bundle makes TestDoFn emit one more main-output element, stamped with the
+    // maximum timestamp of the global window.
+    Iterables.getOnlyElement(finishFunctions).run();
+    assertThat(
+        mainOutputValues,
+        contains(
+            timestampedValueInGlobalWindow("FinishBundle", GlobalWindow.INSTANCE.maxTimestamp())));
+    mainOutputValues.clear();
+  }
+
+ @Test
+ public void testRegistration() {
+ for (Registrar registrar :
+ ServiceLoader.load(Registrar.class)) {
+ if (registrar instanceof FnApiDoFnRunner.Registrar) {
+ assertThat(registrar.getPTransformRunnerFactories(),
+ IsMapContaining.hasKey(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN));
+ return;
+ }
+ }
+ fail("Expected registrar not found.");
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
index a616b2c..0a94b5b 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
@@ -31,11 +31,11 @@ import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Supplier;
+import org.apache.beam.fn.harness.PTransformRunnerFactory;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.fn.ThrowingConsumer;
import org.apache.beam.fn.harness.fn.ThrowingRunnable;
import org.apache.beam.fn.v1.BeamFnApi;
-import org.apache.beam.runners.core.PTransformRunnerFactory;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java
deleted file mode 100644
index d6a476e..0000000
--- a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java
+++ /dev/null
@@ -1,281 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.core;
-
-import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
-import static org.hamcrest.Matchers.contains;
-import static org.hamcrest.Matchers.containsInAnyOrder;
-import static org.junit.Assert.assertThat;
-import static org.junit.Assert.fail;
-import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.eq;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoMoreInteractions;
-import static org.mockito.Mockito.verifyZeroInteractions;
-import static org.mockito.Mockito.when;
-
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.google.common.base.Suppliers;
-import com.google.common.collect.HashMultimap;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.Iterables;
-import com.google.common.collect.Multimap;
-import com.google.common.util.concurrent.Uninterruptibles;
-import com.google.protobuf.Any;
-import com.google.protobuf.ByteString;
-import com.google.protobuf.BytesValue;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.ServiceLoader;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.Executors;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicReference;
-import org.apache.beam.fn.harness.data.BeamFnDataClient;
-import org.apache.beam.fn.harness.fn.ThrowingConsumer;
-import org.apache.beam.fn.harness.fn.ThrowingRunnable;
-import org.apache.beam.fn.harness.test.TestExecutors;
-import org.apache.beam.fn.harness.test.TestExecutors.TestExecutorService;
-import org.apache.beam.fn.v1.BeamFnApi;
-import org.apache.beam.runners.core.PTransformRunnerFactory.Registrar;
-import org.apache.beam.runners.dataflow.util.CloudObjects;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.StringUtf8Coder;
-import org.apache.beam.sdk.common.runner.v1.RunnerApi;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
-import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.KV;
-import org.hamcrest.collection.IsMapContaining;
-import org.junit.Before;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-import org.mockito.ArgumentCaptor;
-import org.mockito.Captor;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-/** Tests for {@link BeamFnDataReadRunner}. */
-@RunWith(JUnit4.class)
-public class BeamFnDataReadRunnerTest {
-
- private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
- private static final BeamFnApi.RemoteGrpcPort PORT_SPEC = BeamFnApi.RemoteGrpcPort.newBuilder()
- .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance()).build();
- private static final RunnerApi.FunctionSpec FUNCTION_SPEC = RunnerApi.FunctionSpec.newBuilder()
- .setParameter(Any.pack(PORT_SPEC)).build();
- private static final Coder<WindowedValue<String>> CODER =
- WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE);
- private static final String CODER_SPEC_ID = "string-coder-id";
- private static final RunnerApi.Coder CODER_SPEC;
- private static final String URN = "urn:org.apache.beam:source:runner:0.1";
-
- static {
- try {
- CODER_SPEC = RunnerApi.Coder.newBuilder().setSpec(
- RunnerApi.SdkFunctionSpec.newBuilder().setSpec(
- RunnerApi.FunctionSpec.newBuilder().setParameter(
- Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom(
- OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(CODER))))
- .build()))
- .build())
- .build())
- .build();
- } catch (IOException e) {
- throw new ExceptionInInitializerError(e);
- }
- }
- private static final BeamFnApi.Target INPUT_TARGET = BeamFnApi.Target.newBuilder()
- .setPrimitiveTransformReference("1")
- .setName("out")
- .build();
-
- @Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool);
- @Mock private BeamFnDataClient mockBeamFnDataClient;
- @Captor private ArgumentCaptor<ThrowingConsumer<WindowedValue<String>>> consumerCaptor;
-
- @Before
- public void setUp() {
- MockitoAnnotations.initMocks(this);
- }
-
- @Test
- public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception {
- String bundleId = "57";
- String outputId = "101";
-
- List<WindowedValue<String>> outputValues = new ArrayList<>();
-
- Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create();
- consumers.put("outputPC",
- (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) outputValues::add);
- List<ThrowingRunnable> startFunctions = new ArrayList<>();
- List<ThrowingRunnable> finishFunctions = new ArrayList<>();
-
- RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder()
- .setUrn("urn:org.apache.beam:source:runner:0.1")
- .setParameter(Any.pack(PORT_SPEC))
- .build();
-
- RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
- .setSpec(functionSpec)
- .putOutputs(outputId, "outputPC")
- .build();
-
- new BeamFnDataReadRunner.Factory<String>().createRunnerForPTransform(
- PipelineOptionsFactory.create(),
- mockBeamFnDataClient,
- "pTransformId",
- pTransform,
- Suppliers.ofInstance(bundleId)::get,
- ImmutableMap.of("outputPC",
- RunnerApi.PCollection.newBuilder().setCoderId(CODER_SPEC_ID).build()),
- ImmutableMap.of(CODER_SPEC_ID, CODER_SPEC),
- consumers,
- startFunctions::add,
- finishFunctions::add);
-
- verifyZeroInteractions(mockBeamFnDataClient);
-
- CompletableFuture<Void> completionFuture = new CompletableFuture<>();
- when(mockBeamFnDataClient.forInboundConsumer(any(), any(), any(), any()))
- .thenReturn(completionFuture);
- Iterables.getOnlyElement(startFunctions).run();
- verify(mockBeamFnDataClient).forInboundConsumer(
- eq(PORT_SPEC.getApiServiceDescriptor()),
- eq(KV.of(bundleId, BeamFnApi.Target.newBuilder()
- .setPrimitiveTransformReference("pTransformId")
- .setName(outputId)
- .build())),
- eq(CODER),
- consumerCaptor.capture());
-
- consumerCaptor.getValue().accept(valueInGlobalWindow("TestValue"));
- assertThat(outputValues, contains(valueInGlobalWindow("TestValue")));
- outputValues.clear();
-
- assertThat(consumers.keySet(), containsInAnyOrder("outputPC"));
-
- completionFuture.complete(null);
- Iterables.getOnlyElement(finishFunctions).run();
-
- verifyNoMoreInteractions(mockBeamFnDataClient);
- }
-
- @Test
- public void testReuseForMultipleBundles() throws Exception {
- CompletableFuture<Void> bundle1Future = new CompletableFuture<>();
- CompletableFuture<Void> bundle2Future = new CompletableFuture<>();
- when(mockBeamFnDataClient.forInboundConsumer(
- any(),
- any(),
- any(),
- any())).thenReturn(bundle1Future).thenReturn(bundle2Future);
- List<WindowedValue<String>> valuesA = new ArrayList<>();
- List<WindowedValue<String>> valuesB = new ArrayList<>();
-
- AtomicReference<String> bundleId = new AtomicReference<>("0");
- BeamFnDataReadRunner<String> readRunner = new BeamFnDataReadRunner<>(
- FUNCTION_SPEC,
- bundleId::get,
- INPUT_TARGET,
- CODER_SPEC,
- mockBeamFnDataClient,
- ImmutableList.of(valuesA::add, valuesB::add));
-
- // Process for bundle id 0
- readRunner.registerInputLocation();
-
- verify(mockBeamFnDataClient).forInboundConsumer(
- eq(PORT_SPEC.getApiServiceDescriptor()),
- eq(KV.of(bundleId.get(), INPUT_TARGET)),
- eq(CODER),
- consumerCaptor.capture());
-
- executor.submit(new Runnable() {
- @Override
- public void run() {
- // Sleep for some small amount of time simulating the parent blocking
- Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);
- try {
- consumerCaptor.getValue().accept(valueInGlobalWindow("ABC"));
- consumerCaptor.getValue().accept(valueInGlobalWindow("DEF"));
- } catch (Exception e) {
- bundle1Future.completeExceptionally(e);
- } finally {
- bundle1Future.complete(null);
- }
- }
- });
-
- readRunner.blockTillReadFinishes();
- assertThat(valuesA, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF")));
- assertThat(valuesB, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF")));
-
- // Process for bundle id 1
- bundleId.set("1");
- valuesA.clear();
- valuesB.clear();
- readRunner.registerInputLocation();
-
- verify(mockBeamFnDataClient).forInboundConsumer(
- eq(PORT_SPEC.getApiServiceDescriptor()),
- eq(KV.of(bundleId.get(), INPUT_TARGET)),
- eq(CODER),
- consumerCaptor.capture());
-
- executor.submit(new Runnable() {
- @Override
- public void run() {
- // Sleep for some small amount of time simulating the parent blocking
- Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);
- try {
- consumerCaptor.getValue().accept(valueInGlobalWindow("GHI"));
- consumerCaptor.getValue().accept(valueInGlobalWindow("JKL"));
- } catch (Exception e) {
- bundle2Future.completeExceptionally(e);
- } finally {
- bundle2Future.complete(null);
- }
- }
- });
-
- readRunner.blockTillReadFinishes();
- assertThat(valuesA, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL")));
- assertThat(valuesB, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL")));
-
- verifyNoMoreInteractions(mockBeamFnDataClient);
- }
-
- @Test
- public void testRegistration() {
- for (Registrar registrar :
- ServiceLoader.load(Registrar.class)) {
- if (registrar instanceof BeamFnDataReadRunner.Registrar) {
- assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN));
- return;
- }
- }
- fail("Expected registrar not found.");
- }
-}