You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by mx...@apache.org on 2016/03/23 19:35:07 UTC
[02/10] incubator-beam git commit: Implement
InProcessPipelineRunner#run
Implement InProcessPipelineRunner#run
Appropriately construct an evaluation context and executor, and start
the pipeline when run is called.
Implement InProcessPipelineResult.
Apply PTransform overrides.
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/158f9f8d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/158f9f8d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/158f9f8d
Branch: refs/heads/master
Commit: 158f9f8d41c63f5a002c6187f4f05f169579dd6d
Parents: 5ecb7aa
Author: Thomas Groh <tg...@google.com>
Authored: Fri Feb 26 17:30:13 2016 -0800
Committer: Maximilian Michels <mx...@apache.org>
Committed: Wed Mar 23 19:27:51 2016 +0100
----------------------------------------------------------------------
.../CachedThreadPoolExecutorServiceFactory.java | 42 ++++
.../ConsumerTrackingPipelineVisitor.java | 173 ++++++++++++++
.../inprocess/ExecutorServiceFactory.java | 32 +++
.../ExecutorServiceParallelExecutor.java | 2 +-
.../inprocess/GroupByKeyEvaluatorFactory.java | 4 +-
.../inprocess/InProcessPipelineOptions.java | 56 +++++
.../inprocess/InProcessPipelineRunner.java | 228 +++++++++++++++---
.../inprocess/KeyedPValueTrackingVisitor.java | 95 ++++++++
.../ConsumerTrackingPipelineVisitorTest.java | 233 +++++++++++++++++++
.../inprocess/InProcessPipelineRunnerTest.java | 77 ++++++
.../KeyedPValueTrackingVisitorTest.java | 189 +++++++++++++++
11 files changed, 1101 insertions(+), 30 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java
----------------------------------------------------------------------
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java
new file mode 100644
index 0000000..3350d2b
--- /dev/null
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CachedThreadPoolExecutorServiceFactory.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import com.google.cloud.dataflow.sdk.options.DefaultValueFactory;
+import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+/**
+ * A {@link ExecutorServiceFactory} that produces cached thread pools via
+ * {@link Executors#newCachedThreadPool()}.
+ */
+class CachedThreadPoolExecutorServiceFactory
+ implements DefaultValueFactory<ExecutorServiceFactory>, ExecutorServiceFactory {
+ private static final CachedThreadPoolExecutorServiceFactory INSTANCE =
+ new CachedThreadPoolExecutorServiceFactory();
+
+ @Override
+ public ExecutorServiceFactory create(PipelineOptions options) {
+ return INSTANCE;
+ }
+
+ @Override
+ public ExecutorService create() {
+ return Executors.newCachedThreadPool();
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java
----------------------------------------------------------------------
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java
new file mode 100644
index 0000000..c602b23
--- /dev/null
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitor.java
@@ -0,0 +1,173 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor;
+import com.google.cloud.dataflow.sdk.runners.PipelineRunner;
+import com.google.cloud.dataflow.sdk.runners.TransformTreeNode;
+import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.values.PCollectionView;
+import com.google.cloud.dataflow.sdk.values.PInput;
+import com.google.cloud.dataflow.sdk.values.PValue;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Tracks the {@link AppliedPTransform AppliedPTransforms} that consume each {@link PValue} in the
+ * {@link Pipeline}. This is used to schedule consuming {@link PTransform PTransforms} to consume
+ * input after the upstream transform has produced and committed output.
+ */
+public class ConsumerTrackingPipelineVisitor implements PipelineVisitor {
+ private Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers = new HashMap<>();
+ private Collection<AppliedPTransform<?, ?, ?>> rootTransforms = new ArrayList<>();
+ private Collection<PCollectionView<?>> views = new ArrayList<>();
+ private Map<AppliedPTransform<?, ?, ?>, String> stepNames = new HashMap<>();
+ private Set<PValue> toFinalize = new HashSet<>();
+ private int numTransforms = 0;
+ private boolean finalized = false;
+
+ @Override
+ public void enterCompositeTransform(TransformTreeNode node) {
+ checkState(
+ !finalized,
+ "Attempting to traverse a pipeline (node %s) with a %s "
+ + "which has already visited a Pipeline and is finalized",
+ node.getFullName(),
+ ConsumerTrackingPipelineVisitor.class.getSimpleName());
+ }
+
+ @Override
+ public void leaveCompositeTransform(TransformTreeNode node) {
+ checkState(
+ !finalized,
+ "Attempting to traverse a pipeline (node %s) with a %s which is already finalized",
+ node.getFullName(),
+ ConsumerTrackingPipelineVisitor.class.getSimpleName());
+ if (node.isRootNode()) {
+ finalized = true;
+ }
+ }
+
+ @Override
+ public void visitTransform(TransformTreeNode node) {
+ toFinalize.removeAll(node.getInput().expand());
+ AppliedPTransform<?, ?, ?> appliedTransform = getAppliedTransform(node);
+ if (node.getInput().expand().isEmpty()) {
+ rootTransforms.add(appliedTransform);
+ } else {
+ for (PValue value : node.getInput().expand()) {
+ valueToConsumers.get(value).add(appliedTransform);
+ stepNames.put(appliedTransform, genStepName());
+ }
+ }
+ }
+
+ private AppliedPTransform<?, ?, ?> getAppliedTransform(TransformTreeNode node) {
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ AppliedPTransform<?, ?, ?> application = AppliedPTransform.of(
+ node.getFullName(), node.getInput(), node.getOutput(), (PTransform) node.getTransform());
+ return application;
+ }
+
+ @Override
+ public void visitValue(PValue value, TransformTreeNode producer) {
+ toFinalize.add(value);
+ for (PValue expandedValue : value.expand()) {
+ valueToConsumers.put(expandedValue, new ArrayList<AppliedPTransform<?, ?, ?>>());
+ if (expandedValue instanceof PCollectionView) {
+ views.add((PCollectionView<?>) expandedValue);
+ }
+ expandedValue.recordAsOutput(getAppliedTransform(producer));
+ }
+ value.recordAsOutput(getAppliedTransform(producer));
+ }
+
+ private String genStepName() {
+ return String.format("s%s", numTransforms++);
+ }
+
+
+ /**
+ * Returns a mapping of each fully-expanded {@link PValue} to each
+ * {@link AppliedPTransform} that consumes it. For each AppliedPTransform in the collection
+ * returned from {@code getValueToConsumers().get(PValue)},
+ * {@code AppliedPTransform#getInput().expand()} will contain the argument {@link PValue}.
+ */
+ public Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> getValueToConsumers() {
+ checkState(
+ finalized,
+ "Can't call getValueToConsumers before the Pipeline has been completely traversed");
+
+ return valueToConsumers;
+ }
+
+ /**
+ * Returns the mapping for each {@link AppliedPTransform} in the {@link Pipeline} to a unique step
+ * name.
+ */
+ public Map<AppliedPTransform<?, ?, ?>, String> getStepNames() {
+ checkState(
+ finalized, "Can't call getStepNames before the Pipeline has been completely traversed");
+
+ return stepNames;
+ }
+
+ /**
+ * Returns the root transforms of the {@link Pipeline}. A root {@link AppliedPTransform} consumes
+ * a {@link PInput} where the {@link PInput#expand()} returns an empty collection.
+ */
+ public Collection<AppliedPTransform<?, ?, ?>> getRootTransforms() {
+ checkState(
+ finalized,
+ "Can't call getRootTransforms before the Pipeline has been completely traversed");
+
+ return rootTransforms;
+ }
+
+ /**
+ * Returns all of the {@link PCollectionView PCollectionViews} contained in the visited
+ * {@link Pipeline}.
+ */
+ public Collection<PCollectionView<?>> getViews() {
+ checkState(finalized, "Can't call getViews before the Pipeline has been completely traversed");
+
+ return views;
+ }
+
+ /**
+ * Returns all of the {@link PValue PValues} that have been produced but not consumed. These
+ * {@link PValue PValues} should be finalized by the {@link PipelineRunner} before the
+ * {@link Pipeline} is executed.
+ */
+ public Set<PValue> getUnfinalizedPValues() {
+ checkState(
+ finalized,
+ "Can't call getUnfinalizedPValues before the Pipeline has been completely traversed");
+
+ return toFinalize;
+ }
+}
+
+
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceFactory.java
----------------------------------------------------------------------
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceFactory.java
new file mode 100644
index 0000000..480bcde
--- /dev/null
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceFactory.java
@@ -0,0 +1,32 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import java.util.concurrent.ExecutorService;
+
+/**
+ * A factory that creates {@link ExecutorService ExecutorServices}.
+ * {@link ExecutorService ExecutorServices} created by this factory should be independent of one
+ * another (e.g., if any executor is shut down the remaining executors should continue to process
+ * work).
+ */
+public interface ExecutorServiceFactory {
+ /**
+ * Create a new {@link ExecutorService}.
+ */
+ ExecutorService create();
+}
+
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java
----------------------------------------------------------------------
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java
index ae686f2..c72a115 100644
--- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java
@@ -126,7 +126,7 @@ final class ExecutorServiceParallelExecutor implements InProcessExecutor {
@Nullable final CommittedBundle<T> bundle,
final CompletionCallback onComplete) {
TransformExecutorService transformExecutor;
- if (isKeyed(bundle.getPCollection())) {
+ if (bundle != null && isKeyed(bundle.getPCollection())) {
final StepAndKey stepAndKey =
StepAndKey.of(transform, bundle == null ? null : bundle.getKey());
transformExecutor = getSerialExecutorService(stepAndKey);
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java
index dec78d6..3ec4af1 100644
--- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/GroupByKeyEvaluatorFactory.java
@@ -59,7 +59,7 @@ class GroupByKeyEvaluatorFactory implements TransformEvaluatorFactory {
CommittedBundle<?> inputBundle,
InProcessEvaluationContext evaluationContext) {
@SuppressWarnings({"cast", "unchecked", "rawtypes"})
- TransformEvaluator<InputT> evaluator = (TransformEvaluator<InputT>) createEvaluator(
+ TransformEvaluator<InputT> evaluator = createEvaluator(
(AppliedPTransform) application, (CommittedBundle) inputBundle, evaluationContext);
return evaluator;
}
@@ -184,7 +184,7 @@ class GroupByKeyEvaluatorFactory implements TransformEvaluatorFactory {
extends ForwardingPTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> {
private final GroupByKey<K, V> original;
- public InProcessGroupByKey(GroupByKey<K, V> from) {
+ private InProcessGroupByKey(GroupByKey<K, V> from) {
this.original = from;
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java
----------------------------------------------------------------------
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java
index 27e9a4b..5ee0e88 100644
--- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineOptions.java
@@ -15,20 +15,76 @@
*/
package com.google.cloud.dataflow.sdk.runners.inprocess;
+import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.options.ApplicationNameOptions;
import com.google.cloud.dataflow.sdk.options.Default;
+import com.google.cloud.dataflow.sdk.options.Description;
+import com.google.cloud.dataflow.sdk.options.Hidden;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.options.Validation.Required;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
/**
* Options that can be used to configure the {@link InProcessPipelineRunner}.
*/
public interface InProcessPipelineOptions extends PipelineOptions, ApplicationNameOptions {
+ /**
+ * Gets the {@link ExecutorServiceFactory} to use to create instances of {@link ExecutorService}
+ * to execute {@link PTransform PTransforms}.
+ *
+ * <p>Note that {@link ExecutorService ExecutorServices} returned by the factory must ensure that
+ * it cannot enter a state in which it will not schedule additional pending work unless currently
+ * scheduled work completes, as this may cause the {@link Pipeline} to cease processing.
+ *
+ * <p>Defaults to a {@link CachedThreadPoolExecutorServiceFactory}, which produces instances of
+ * {@link Executors#newCachedThreadPool()}.
+ */
+ @JsonIgnore
+ @Required
+ @Hidden
+ @Default.InstanceFactory(CachedThreadPoolExecutorServiceFactory.class)
+ ExecutorServiceFactory getExecutorServiceFactory();
+
+ void setExecutorServiceFactory(ExecutorServiceFactory executorService);
+
+ /**
+ * Gets the {@link Clock} used by this pipeline. The clock is used in place of accessing the
+ * system time when time values are required by the evaluator.
+ */
@Default.InstanceFactory(NanosOffsetClock.Factory.class)
+ @JsonIgnore
+ @Required
+ @Hidden
+ @Description(
+ "The processing time source used by the pipeline. When the current time is "
+ + "needed by the evaluator, the result of clock#now() is used.")
Clock getClock();
void setClock(Clock clock);
+ @Default.Boolean(false)
+ @Description(
+ "If the pipeline should shut down producers which have reached the maximum "
+ + "representable watermark. If this is set to true, a pipeline in which all PTransforms "
+ + "have reached the maximum watermark will be shut down, even if there are unbounded "
+ + "sources that could produce additional (late) data. By default, if the pipeline "
+ + "contains any unbounded PCollections, it will run until explicitly shut down.")
boolean isShutdownUnboundedProducersWithMaxWatermark();
void setShutdownUnboundedProducersWithMaxWatermark(boolean shutdown);
+
+ @Default.Boolean(true)
+ @Description(
+ "If the pipeline should block awaiting completion of the pipeline. If set to true, "
+ + "a call to Pipeline#run() will block until all PTransforms are complete. Otherwise, "
+ + "the Pipeline will execute asynchronously. If set to false, the completion of the "
+ + "pipeline can be awaited on by use of InProcessPipelineResult#awaitCompletion().")
+ boolean isBlockOnRun();
+
+ void setBlockOnRun(boolean b);
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
----------------------------------------------------------------------
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
index 32859da..a1c8756 100644
--- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2015 Google Inc.
+ * Copyright (C) 2016 Google Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
@@ -15,25 +15,46 @@
*/
package com.google.cloud.dataflow.sdk.runners.inprocess;
-import static com.google.common.base.Preconditions.checkArgument;
-
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException;
+import com.google.cloud.dataflow.sdk.PipelineResult;
import com.google.cloud.dataflow.sdk.annotations.Experimental;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.runners.AggregatorPipelineExtractor;
+import com.google.cloud.dataflow.sdk.runners.AggregatorRetrievalException;
+import com.google.cloud.dataflow.sdk.runners.AggregatorValues;
+import com.google.cloud.dataflow.sdk.runners.PipelineRunner;
import com.google.cloud.dataflow.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKey;
-import com.google.cloud.dataflow.sdk.runners.inprocess.ViewEvaluatorFactory.InProcessCreatePCollectionView;
+import com.google.cloud.dataflow.sdk.runners.inprocess.GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly;
+import com.google.cloud.dataflow.sdk.transforms.Aggregator;
+import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
+import com.google.cloud.dataflow.sdk.transforms.Create;
import com.google.cloud.dataflow.sdk.transforms.GroupByKey;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.transforms.View.CreatePCollectionView;
+import com.google.cloud.dataflow.sdk.util.InstanceBuilder;
+import com.google.cloud.dataflow.sdk.util.MapAggregatorValues;
import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData;
+import com.google.cloud.dataflow.sdk.util.UserCodeException;
import com.google.cloud.dataflow.sdk.util.WindowedValue;
+import com.google.cloud.dataflow.sdk.util.common.Counter;
+import com.google.cloud.dataflow.sdk.util.common.CounterSet;
import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PCollection.IsBounded;
import com.google.cloud.dataflow.sdk.values.PCollectionView;
+import com.google.cloud.dataflow.sdk.values.PInput;
+import com.google.cloud.dataflow.sdk.values.POutput;
+import com.google.cloud.dataflow.sdk.values.PValue;
+import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
import org.joda.time.Instant;
+import java.util.Collection;
+import java.util.HashMap;
import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutorService;
import javax.annotation.Nullable;
@@ -42,28 +63,25 @@ import javax.annotation.Nullable;
* {@link PCollection PCollections}.
*/
@Experimental
-public class InProcessPipelineRunner {
- @SuppressWarnings({"rawtypes", "unused"})
+public class InProcessPipelineRunner
+ extends PipelineRunner<InProcessPipelineRunner.InProcessPipelineResult> {
+ /**
+ * The default set of transform overrides to use in the {@link InProcessPipelineRunner}.
+ *
+ * <p>A transform override must have a single-argument constructor that takes an instance of the
+ * type of transform it is overriding.
+ */
+ @SuppressWarnings("rawtypes")
private static Map<Class<? extends PTransform>, Class<? extends PTransform>>
defaultTransformOverrides =
ImmutableMap.<Class<? extends PTransform>, Class<? extends PTransform>>builder()
+ .put(Create.Values.class, InProcessCreate.class)
.put(GroupByKey.class, InProcessGroupByKey.class)
- .put(CreatePCollectionView.class, InProcessCreatePCollectionView.class)
+ .put(
+ CreatePCollectionView.class,
+ ViewEvaluatorFactory.InProcessCreatePCollectionView.class)
.build();
- private static Map<Class<?>, TransformEvaluatorFactory> defaultEvaluatorFactories =
- new ConcurrentHashMap<>();
-
- /**
- * Register a default transform evaluator.
- */
- public static <TransformT extends PTransform<?, ?>> void registerTransformEvaluatorFactory(
- Class<TransformT> clazz, TransformEvaluatorFactory evaluator) {
- checkArgument(defaultEvaluatorFactories.put(clazz, evaluator) == null,
- "Defining a default factory %s to evaluate Transforms of type %s multiple times", evaluator,
- clazz);
- }
-
/**
* Part of a {@link PCollection}. Elements are output to a bundle, which will cause them to be
* executed by {@link PTransform PTransforms} that consume the {@link PCollection} this bundle is
@@ -73,7 +91,7 @@ public class InProcessPipelineRunner {
*/
public static interface UncommittedBundle<T> {
/**
- * Returns the PCollection that the elements of this bundle belong to.
+ * Returns the PCollection that the elements of this {@link UncommittedBundle} belong to.
*/
PCollection<T> getPCollection();
@@ -103,14 +121,13 @@ public class InProcessPipelineRunner {
* @param <T> the type of elements contained within this bundle
*/
public static interface CommittedBundle<T> {
-
/**
* Returns the PCollection that the elements of this bundle belong to.
*/
PCollection<T> getPCollection();
/**
- * Returns weather this bundle is keyed. A bundle that is part of a {@link PCollection} that
+ * Returns whether this bundle is keyed. A bundle that is part of a {@link PCollection} that
* occurs after a {@link GroupByKey} is keyed by the result of the last {@link GroupByKey}.
*/
boolean isKeyed();
@@ -119,11 +136,12 @@ public class InProcessPipelineRunner {
* Returns the (possibly null) key that was output in the most recent {@link GroupByKey} in the
* execution of this bundle.
*/
- @Nullable Object getKey();
+ @Nullable
+ Object getKey();
/**
- * @return an {@link Iterable} containing all of the elements that have been added to this
- * {@link CommittedBundle}
+ * Returns an {@link Iterable} containing all of the elements that have been added to this
+ * {@link CommittedBundle}.
*/
Iterable<WindowedValue<T>> getElements();
@@ -166,4 +184,160 @@ public class InProcessPipelineRunner {
public InProcessPipelineOptions getPipelineOptions() {
return options;
}
+
+ @Override
+ public <OutputT extends POutput, InputT extends PInput> OutputT apply(
+ PTransform<InputT, OutputT> transform, InputT input) {
+ Class<?> overrideClass = defaultTransformOverrides.get(transform.getClass());
+ if (overrideClass != null) {
+ // It is the responsibility of whoever constructs overrides to ensure this is type safe.
+ @SuppressWarnings("unchecked")
+ Class<PTransform<InputT, OutputT>> transformClass =
+ (Class<PTransform<InputT, OutputT>>) transform.getClass();
+
+ @SuppressWarnings("unchecked")
+ Class<PTransform<InputT, OutputT>> customTransformClass =
+ (Class<PTransform<InputT, OutputT>>) overrideClass;
+
+ PTransform<InputT, OutputT> customTransform =
+ InstanceBuilder.ofType(customTransformClass)
+ .withArg(transformClass, transform)
+ .build();
+
+ // This overrides the contents of the apply method without changing the TransformTreeNode that
+ // is generated by the PCollection application.
+ return super.apply(customTransform, input);
+ } else {
+ return super.apply(transform, input);
+ }
+ }
+
+ @Override
+ public InProcessPipelineResult run(Pipeline pipeline) {
+ ConsumerTrackingPipelineVisitor consumerTrackingVisitor = new ConsumerTrackingPipelineVisitor();
+ pipeline.traverseTopologically(consumerTrackingVisitor);
+ for (PValue unfinalized : consumerTrackingVisitor.getUnfinalizedPValues()) {
+ unfinalized.finishSpecifying();
+ }
+ @SuppressWarnings("rawtypes")
+ KeyedPValueTrackingVisitor keyedPValueVisitor =
+ KeyedPValueTrackingVisitor.create(
+ ImmutableSet.<Class<? extends PTransform>>of(
+ GroupByKey.class, InProcessGroupByKeyOnly.class));
+ pipeline.traverseTopologically(keyedPValueVisitor);
+
+ InProcessEvaluationContext context =
+ InProcessEvaluationContext.create(
+ getPipelineOptions(),
+ consumerTrackingVisitor.getRootTransforms(),
+ consumerTrackingVisitor.getValueToConsumers(),
+ consumerTrackingVisitor.getStepNames(),
+ consumerTrackingVisitor.getViews());
+
+ // independent executor service for each run
+ ExecutorService executorService =
+ context.getPipelineOptions().getExecutorServiceFactory().create();
+ InProcessExecutor executor =
+ ExecutorServiceParallelExecutor.create(
+ executorService,
+ consumerTrackingVisitor.getValueToConsumers(),
+ keyedPValueVisitor.getKeyedPValues(),
+ TransformEvaluatorRegistry.defaultRegistry(),
+ context);
+ executor.start(consumerTrackingVisitor.getRootTransforms());
+
+ Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps =
+ new AggregatorPipelineExtractor(pipeline).getAggregatorSteps();
+ InProcessPipelineResult result =
+ new InProcessPipelineResult(executor, context, aggregatorSteps);
+ if (options.isBlockOnRun()) {
+ try {
+ result.awaitCompletion();
+ } catch (UserCodeException userException) {
+ throw new PipelineExecutionException(userException.getCause());
+ } catch (Throwable t) {
+ Throwables.propagate(t);
+ }
+ }
+ return result;
+ }
+
+ /**
+ * The result of running a {@link Pipeline} with the {@link InProcessPipelineRunner}.
+ *
+ * <p>The terminal state of the pipeline can be obtained via {@link #getState()}, and completion
+ * can be awaited via {@link #awaitCompletion()}.
+ */
+ public static class InProcessPipelineResult implements PipelineResult {
+ private final InProcessExecutor executor;
+ private final InProcessEvaluationContext evaluationContext;
+ private final Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps;
+ private State state;
+
+ private InProcessPipelineResult(
+ InProcessExecutor executor,
+ InProcessEvaluationContext evaluationContext,
+ Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps) {
+ this.executor = executor;
+ this.evaluationContext = evaluationContext;
+ this.aggregatorSteps = aggregatorSteps;
+ // Only ever constructed after the executor has started.
+ this.state = State.RUNNING;
+ }
+
+ @Override
+ public State getState() {
+ return state;
+ }
+
+ @Override
+ public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator)
+ throws AggregatorRetrievalException {
+ CounterSet counters = evaluationContext.getCounters();
+ Collection<PTransform<?, ?>> steps = aggregatorSteps.get(aggregator);
+ Map<String, T> stepValues = new HashMap<>();
+ for (AppliedPTransform<?, ?, ?> transform : evaluationContext.getSteps()) {
+ if (steps.contains(transform.getTransform())) {
+ String stepName =
+ String.format(
+ "user-%s-%s", evaluationContext.getStepName(transform), aggregator.getName());
+ Counter<T> counter = (Counter<T>) counters.getExistingCounter(stepName);
+ if (counter != null) {
+ stepValues.put(transform.getFullName(), counter.getAggregate());
+ }
+ }
+ }
+ return new MapAggregatorValues<>(stepValues);
+ }
+
+ /**
+ * Blocks until the {@link Pipeline} execution represented by this
+ * {@link InProcessPipelineResult} is complete, returning the terminal state.
+ *
+ * <p>If the pipeline terminates abnormally by throwing an exception, this will rethrow the
+ * exception. Future calls to {@link #getState()} will return
+ * {@link com.google.cloud.dataflow.sdk.PipelineResult.State#FAILED}.
+ *
+ * <p>NOTE: if the {@link Pipeline} contains an {@link IsBounded#UNBOUNDED unbounded}
+ * {@link PCollection}, and the {@link PipelineRunner} was created with
+ * {@link InProcessPipelineOptions#isShutdownUnboundedProducersWithMaxWatermark()} set to false,
+ * this method will never return.
+ *
+ * See also {@link InProcessExecutor#awaitCompletion()}.
+ */
+ public State awaitCompletion() throws Throwable {
+ if (!state.isTerminal()) {
+ try {
+ executor.awaitCompletion();
+ state = State.DONE;
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw e;
+ } catch (Throwable t) {
+ state = State.FAILED;
+ throw t;
+ }
+ }
+ return state;
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitor.java
----------------------------------------------------------------------
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitor.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitor.java
new file mode 100644
index 0000000..23a8c0f
--- /dev/null
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitor.java
@@ -0,0 +1,95 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor;
+import com.google.cloud.dataflow.sdk.runners.TransformTreeNode;
+import com.google.cloud.dataflow.sdk.transforms.GroupByKey;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.values.PValue;
+
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ * A pipeline visitor that tracks all keyed {@link PValue PValues}. A {@link PValue} is keyed if it
+ * is the result of a {@link PTransform} that produces keyed outputs. A {@link PTransform} that
+ * produces keyed outputs is assumed to colocate output elements that share a key.
+ *
+ * <p>All {@link GroupByKey} transforms, or their runner-specific implementation primitive, produce
+ * keyed output.
+ */
+// TODO: Handle Key-preserving transforms when appropriate and more aggressively make PTransforms
+// unkeyed
+class KeyedPValueTrackingVisitor implements PipelineVisitor {
+ @SuppressWarnings("rawtypes")
+ private final Set<Class<? extends PTransform>> producesKeyedOutputs;
+ private final Set<PValue> keyedValues;
+ private boolean finalized;
+
+ public static KeyedPValueTrackingVisitor create(
+ @SuppressWarnings("rawtypes") Set<Class<? extends PTransform>> producesKeyedOutputs) {
+ return new KeyedPValueTrackingVisitor(producesKeyedOutputs);
+ }
+
+ private KeyedPValueTrackingVisitor(
+ @SuppressWarnings("rawtypes") Set<Class<? extends PTransform>> producesKeyedOutputs) {
+ this.producesKeyedOutputs = producesKeyedOutputs;
+ this.keyedValues = new HashSet<>();
+ }
+
+ @Override
+ public void enterCompositeTransform(TransformTreeNode node) {
+ checkState(
+ !finalized,
+ "Attempted to use a %s that has already been finalized on a pipeline (visiting node %s)",
+ KeyedPValueTrackingVisitor.class.getSimpleName(),
+ node);
+ }
+
+ @Override
+ public void leaveCompositeTransform(TransformTreeNode node) {
+ checkState(
+ !finalized,
+ "Attempted to use a %s that has already been finalized on a pipeline (visiting node %s)",
+ KeyedPValueTrackingVisitor.class.getSimpleName(),
+ node);
+ if (node.isRootNode()) {
+ finalized = true;
+ } else if (producesKeyedOutputs.contains(node.getTransform().getClass())) {
+ keyedValues.addAll(node.getExpandedOutputs());
+ }
+ }
+
+ @Override
+ public void visitTransform(TransformTreeNode node) {}
+
+ @Override
+ public void visitValue(PValue value, TransformTreeNode producer) {
+ if (producesKeyedOutputs.contains(producer.getTransform().getClass())) {
+ keyedValues.addAll(value.expand());
+ }
+ }
+
+ public Set<PValue> getKeyedPValues() {
+ checkState(
+ finalized, "can't call getKeyedPValues before a Pipeline has been completely traversed");
+ return keyedValues;
+ }
+}
+
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitorTest.java
----------------------------------------------------------------------
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitorTest.java
new file mode 100644
index 0000000..d921f6c
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/ConsumerTrackingPipelineVisitorTest.java
@@ -0,0 +1,233 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import static org.hamcrest.Matchers.emptyIterable;
+import static org.junit.Assert.assertThat;
+
+import com.google.cloud.dataflow.sdk.io.CountingInput;
+import com.google.cloud.dataflow.sdk.testing.TestPipeline;
+import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.DoFn;
+import com.google.cloud.dataflow.sdk.transforms.Flatten;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.transforms.ParDo;
+import com.google.cloud.dataflow.sdk.transforms.View;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PCollectionList;
+import com.google.cloud.dataflow.sdk.values.PCollectionView;
+import com.google.cloud.dataflow.sdk.values.PDone;
+import com.google.cloud.dataflow.sdk.values.PInput;
+import com.google.cloud.dataflow.sdk.values.PValue;
+
+import org.hamcrest.Matchers;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.Serializable;
+import java.util.List;
+
+/**
+ * Tests for {@link ConsumerTrackingPipelineVisitor}.
+ */
+@RunWith(JUnit4.class)
+public class ConsumerTrackingPipelineVisitorTest implements Serializable {
+ @Rule public transient ExpectedException thrown = ExpectedException.none();
+
+ private transient TestPipeline p = TestPipeline.create();
+ private transient ConsumerTrackingPipelineVisitor visitor = new ConsumerTrackingPipelineVisitor();
+
+ @Test
+ public void getViewsReturnsViews() {
+ PCollectionView<List<String>> listView =
+ p.apply("listCreate", Create.of("foo", "bar"))
+ .apply(
+ ParDo.of(
+ new DoFn<String, String>() {
+ @Override
+ public void processElement(DoFn<String, String>.ProcessContext c)
+ throws Exception {
+ c.output(Integer.toString(c.element().length()));
+ }
+ }))
+ .apply(View.<String>asList());
+ PCollectionView<Object> singletonView =
+ p.apply("singletonCreate", Create.<Object>of(1, 2, 3)).apply(View.<Object>asSingleton());
+ p.traverseTopologically(visitor);
+ assertThat(
+ visitor.getViews(),
+ Matchers.<PCollectionView<?>>containsInAnyOrder(listView, singletonView));
+ }
+
+ @Test
+ public void getRootTransformsContainsPBegins() {
+ PCollection<String> created = p.apply(Create.of("foo", "bar"));
+ PCollection<Long> counted = p.apply(CountingInput.upTo(1234L));
+ PCollection<Long> unCounted = p.apply(CountingInput.unbounded());
+ p.traverseTopologically(visitor);
+ assertThat(
+ visitor.getRootTransforms(),
+ Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder(
+ created.getProducingTransformInternal(),
+ counted.getProducingTransformInternal(),
+ unCounted.getProducingTransformInternal()));
+ }
+
+ @Test
+ public void getRootTransformsContainsEmptyFlatten() {
+ PCollection<String> empty =
+ PCollectionList.<String>empty(p).apply(Flatten.<String>pCollections());
+ p.traverseTopologically(visitor);
+ assertThat(
+ visitor.getRootTransforms(),
+ Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder(
+ empty.getProducingTransformInternal()));
+ }
+
+ @Test
+ public void getValueToConsumersSucceeds() {
+ PCollection<String> created = p.apply(Create.of("1", "2", "3"));
+ PCollection<String> transformed =
+ created.apply(
+ ParDo.of(
+ new DoFn<String, String>() {
+ @Override
+ public void processElement(DoFn<String, String>.ProcessContext c)
+ throws Exception {
+ c.output(Integer.toString(c.element().length()));
+ }
+ }));
+
+ PCollection<String> flattened =
+ PCollectionList.of(created).and(transformed).apply(Flatten.<String>pCollections());
+
+ p.traverseTopologically(visitor);
+
+ assertThat(
+ visitor.getValueToConsumers().get(created),
+ Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder(
+ transformed.getProducingTransformInternal(),
+ flattened.getProducingTransformInternal()));
+ assertThat(
+ visitor.getValueToConsumers().get(transformed),
+ Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder(
+ flattened.getProducingTransformInternal()));
+ assertThat(visitor.getValueToConsumers().get(flattened), emptyIterable());
+ }
+
+ @Test
+ public void getUnfinalizedPValuesContainsDanglingOutputs() {
+ PCollection<String> created = p.apply(Create.of("1", "2", "3"));
+ PCollection<String> transformed =
+ created.apply(
+ ParDo.of(
+ new DoFn<String, String>() {
+ @Override
+ public void processElement(DoFn<String, String>.ProcessContext c)
+ throws Exception {
+ c.output(Integer.toString(c.element().length()));
+ }
+ }));
+
+ p.traverseTopologically(visitor);
+ assertThat(visitor.getUnfinalizedPValues(), Matchers.<PValue>contains(transformed));
+ }
+
+ @Test
+ public void getUnfinalizedPValuesEmpty() {
+ p.apply(Create.of("1", "2", "3"))
+ .apply(
+ ParDo.of(
+ new DoFn<String, String>() {
+ @Override
+ public void processElement(DoFn<String, String>.ProcessContext c)
+ throws Exception {
+ c.output(Integer.toString(c.element().length()));
+ }
+ }))
+ .apply(
+ new PTransform<PInput, PDone>() {
+ @Override
+ public PDone apply(PInput input) {
+ return PDone.in(input.getPipeline());
+ }
+ });
+
+ p.traverseTopologically(visitor);
+ assertThat(visitor.getUnfinalizedPValues(), emptyIterable());
+ }
+
+ @Test
+ public void traverseMultipleTimesThrows() {
+ p.apply(Create.of(1, 2, 3));
+
+ p.traverseTopologically(visitor);
+ thrown.expect(IllegalStateException.class);
+ thrown.expectMessage(ConsumerTrackingPipelineVisitor.class.getSimpleName());
+ thrown.expectMessage("is finalized");
+ p.traverseTopologically(visitor);
+ }
+
+ @Test
+ public void traverseIndependentPathsSucceeds() {
+ p.apply("left", Create.of(1, 2, 3));
+ p.apply("right", Create.of("foo", "bar", "baz"));
+
+ p.traverseTopologically(visitor);
+ }
+
+ @Test
+ public void getRootTransformsWithoutVisitingThrows() {
+ thrown.expect(IllegalStateException.class);
+ thrown.expectMessage("completely traversed");
+ thrown.expectMessage("getRootTransforms");
+ visitor.getRootTransforms();
+ }
+ @Test
+ public void getStepNamesWithoutVisitingThrows() {
+ thrown.expect(IllegalStateException.class);
+ thrown.expectMessage("completely traversed");
+ thrown.expectMessage("getStepNames");
+ visitor.getStepNames();
+ }
+ @Test
+ public void getUnfinalizedPValuesWithoutVisitingThrows() {
+ thrown.expect(IllegalStateException.class);
+ thrown.expectMessage("completely traversed");
+ thrown.expectMessage("getUnfinalizedPValues");
+ visitor.getUnfinalizedPValues();
+ }
+
+ @Test
+ public void getValueToConsumersWithoutVisitingThrows() {
+ thrown.expect(IllegalStateException.class);
+ thrown.expectMessage("completely traversed");
+ thrown.expectMessage("getValueToConsumers");
+ visitor.getValueToConsumers();
+ }
+
+ @Test
+ public void getViewsWithoutVisitingThrows() {
+ thrown.expect(IllegalStateException.class);
+ thrown.expectMessage("completely traversed");
+ thrown.expectMessage("getViews");
+ visitor.getViews();
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java
----------------------------------------------------------------------
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java
new file mode 100644
index 0000000..adb64cd
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunnerTest.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright (C) 2015 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
+import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.InProcessPipelineResult;
+import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
+import com.google.cloud.dataflow.sdk.transforms.Count;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.MapElements;
+import com.google.cloud.dataflow.sdk.transforms.SimpleFunction;
+import com.google.cloud.dataflow.sdk.values.KV;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.Serializable;
+
+/**
+ * Tests for basic {@link InProcessPipelineRunner} functionality.
+ */
+@RunWith(JUnit4.class)
+public class InProcessPipelineRunnerTest implements Serializable {
+ @Test
+ public void wordCountShouldSucceed() throws Throwable {
+ Pipeline p = getPipeline();
+
+ PCollection<KV<String, Long>> counts =
+ p.apply(Create.of("foo", "bar", "foo", "baz", "bar", "foo"))
+ .apply(MapElements.via(new SimpleFunction<String, String>() {
+ @Override
+ public String apply(String input) {
+ return input;
+ }
+ }))
+ .apply(Count.<String>perElement());
+ PCollection<String> countStrs =
+ counts.apply(MapElements.via(new SimpleFunction<KV<String, Long>, String>() {
+ @Override
+ public String apply(KV<String, Long> input) {
+ String str = String.format("%s: %s", input.getKey(), input.getValue());
+ return str;
+ }
+ }));
+
+ DataflowAssert.that(countStrs).containsInAnyOrder("baz: 1", "bar: 2", "foo: 3");
+
+ InProcessPipelineResult result = ((InProcessPipelineResult) p.run());
+ result.awaitCompletion();
+ }
+
+ private Pipeline getPipeline() {
+ PipelineOptions opts = PipelineOptionsFactory.create();
+ opts.setRunner(InProcessPipelineRunner.class);
+
+ Pipeline p = Pipeline.create(opts);
+ return p;
+ }
+}
+
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158f9f8d/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitorTest.java
----------------------------------------------------------------------
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitorTest.java
new file mode 100644
index 0000000..0aaccc2
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/KeyedPValueTrackingVisitorTest.java
@@ -0,0 +1,189 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import static org.hamcrest.Matchers.hasItem;
+import static org.hamcrest.Matchers.not;
+import static org.junit.Assert.assertThat;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.coders.IterableCoder;
+import com.google.cloud.dataflow.sdk.coders.KvCoder;
+import com.google.cloud.dataflow.sdk.coders.VarIntCoder;
+import com.google.cloud.dataflow.sdk.coders.VoidCoder;
+import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.DoFn;
+import com.google.cloud.dataflow.sdk.transforms.GroupByKey;
+import com.google.cloud.dataflow.sdk.transforms.Keys;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.transforms.ParDo;
+import com.google.cloud.dataflow.sdk.values.KV;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.common.collect.ImmutableSet;
+
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.util.Collections;
+import java.util.Set;
+
+/**
+ * Tests for {@link KeyedPValueTrackingVisitor}.
+ */
+@RunWith(JUnit4.class)
+public class KeyedPValueTrackingVisitorTest {
+ @Rule public ExpectedException thrown = ExpectedException.none();
+
+ private KeyedPValueTrackingVisitor visitor;
+ private Pipeline p;
+
+ @Before
+ public void setup() {
+ PipelineOptions options = PipelineOptionsFactory.create();
+
+ p = Pipeline.create(options);
+ @SuppressWarnings("rawtypes")
+ Set<Class<? extends PTransform>> producesKeyed =
+ ImmutableSet.<Class<? extends PTransform>>of(PrimitiveKeyer.class, CompositeKeyer.class);
+ visitor = KeyedPValueTrackingVisitor.create(producesKeyed);
+ }
+
+ @Test
+ public void primitiveProducesKeyedOutputUnkeyedInputKeyedOutput() {
+ PCollection<Integer> keyed =
+ p.apply(Create.<Integer>of(1, 2, 3)).apply(new PrimitiveKeyer<Integer>());
+
+ p.traverseTopologically(visitor);
+ assertThat(visitor.getKeyedPValues(), hasItem(keyed));
+ }
+
+ @Test
+ public void primitiveProducesKeyedOutputKeyedInputKeyedOutut() {
+ PCollection<Integer> keyed =
+ p.apply(Create.<Integer>of(1, 2, 3))
+ .apply("firstKey", new PrimitiveKeyer<Integer>())
+ .apply("secondKey", new PrimitiveKeyer<Integer>());
+
+ p.traverseTopologically(visitor);
+ assertThat(visitor.getKeyedPValues(), hasItem(keyed));
+ }
+
+ @Test
+ public void compositeProducesKeyedOutputUnkeyedInputKeyedOutput() {
+ PCollection<Integer> keyed =
+ p.apply(Create.<Integer>of(1, 2, 3)).apply(new CompositeKeyer<Integer>());
+
+ p.traverseTopologically(visitor);
+ assertThat(visitor.getKeyedPValues(), hasItem(keyed));
+ }
+
+ @Test
+ public void compositeProducesKeyedOutputKeyedInputKeyedOutut() {
+ PCollection<Integer> keyed =
+ p.apply(Create.<Integer>of(1, 2, 3))
+ .apply("firstKey", new CompositeKeyer<Integer>())
+ .apply("secondKey", new CompositeKeyer<Integer>());
+
+ p.traverseTopologically(visitor);
+ assertThat(visitor.getKeyedPValues(), hasItem(keyed));
+ }
+
+
+ @Test
+ public void noInputUnkeyedOutput() {
+ PCollection<KV<Integer, Iterable<Void>>> unkeyed =
+ p.apply(
+ Create.of(KV.<Integer, Iterable<Void>>of(-1, Collections.<Void>emptyList()))
+ .withCoder(KvCoder.of(VarIntCoder.of(), IterableCoder.of(VoidCoder.of()))));
+
+ p.traverseTopologically(visitor);
+ assertThat(visitor.getKeyedPValues(), not(hasItem(unkeyed)));
+ }
+
+ @Test
+ public void keyedInputNotProducesKeyedOutputUnkeyedOutput() {
+ PCollection<Integer> onceKeyed =
+ p.apply(Create.<Integer>of(1, 2, 3))
+ .apply(new PrimitiveKeyer<Integer>())
+ .apply(ParDo.of(new IdentityFn<Integer>()));
+
+ p.traverseTopologically(visitor);
+ assertThat(visitor.getKeyedPValues(), not(hasItem(onceKeyed)));
+ }
+
+ @Test
+ public void unkeyedInputNotProducesKeyedOutputUnkeyedOutput() {
+ PCollection<Integer> unkeyed =
+ p.apply(Create.<Integer>of(1, 2, 3)).apply(ParDo.of(new IdentityFn<Integer>()));
+
+ p.traverseTopologically(visitor);
+ assertThat(visitor.getKeyedPValues(), not(hasItem(unkeyed)));
+ }
+
+ @Test
+ public void traverseMultipleTimesThrows() {
+ p.apply(
+ Create.<KV<Integer, Void>>of(
+ KV.of(1, (Void) null), KV.of(2, (Void) null), KV.of(3, (Void) null))
+ .withCoder(KvCoder.of(VarIntCoder.of(), VoidCoder.of())))
+ .apply(GroupByKey.<Integer, Void>create())
+ .apply(Keys.<Integer>create());
+
+ p.traverseTopologically(visitor);
+
+ thrown.expect(IllegalStateException.class);
+ thrown.expectMessage("already been finalized");
+ thrown.expectMessage(KeyedPValueTrackingVisitor.class.getSimpleName());
+ p.traverseTopologically(visitor);
+ }
+
+ @Test
+ public void getKeyedPValuesBeforeTraverseThrows() {
+ thrown.expect(IllegalStateException.class);
+ thrown.expectMessage("completely traversed");
+ thrown.expectMessage("getKeyedPValues");
+ visitor.getKeyedPValues();
+ }
+
+ private static class PrimitiveKeyer<K> extends PTransform<PCollection<K>, PCollection<K>> {
+ @Override
+ public PCollection<K> apply(PCollection<K> input) {
+ return PCollection.<K>createPrimitiveOutputInternal(
+ input.getPipeline(), input.getWindowingStrategy(), input.isBounded())
+ .setCoder(input.getCoder());
+ }
+ }
+
+ private static class CompositeKeyer<K> extends PTransform<PCollection<K>, PCollection<K>> {
+ @Override
+ public PCollection<K> apply(PCollection<K> input) {
+ return input.apply(new PrimitiveKeyer<K>()).apply(ParDo.of(new IdentityFn<K>()));
+ }
+ }
+
+ private static class IdentityFn<K> extends DoFn<K, K> {
+ @Override
+ public void processElement(DoFn<K, K>.ProcessContext c) throws Exception {
+ c.output(c.element());
+ }
+ }
+}