You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ke...@apache.org on 2016/08/10 18:34:42 UTC
[3/4] incubator-beam git commit: Move aggregator support classes out of runners namespace, make private
Move aggregator support classes out of runners namespace, make private
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/adec254d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/adec254d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/adec254d
Branch: refs/heads/master
Commit: adec254d5fdb409e786a1fc2bcee38f8a7a04408
Parents: 9da4bbc
Author: Kenneth Knowles <kl...@google.com>
Authored: Fri Jul 1 14:56:20 2016 -0700
Committer: Kenneth Knowles <kl...@google.com>
Committed: Wed Aug 10 11:34:03 2016 -0700
----------------------------------------------------------------------
.../beam/runners/direct/DirectRunner.java | 7 +-
.../beam/runners/flink/FlinkRunnerResult.java | 4 +-
.../runners/dataflow/DataflowPipelineJob.java | 4 +-
.../beam/runners/dataflow/DataflowRunner.java | 4 +-
.../dataflow/DataflowPipelineJobTest.java | 4 +-
.../spark/translation/EvaluationContext.java | 4 +-
.../spark/translation/SparkRuntimeContext.java | 2 +-
.../translation/MultiOutputWordCountTest.java | 2 +-
.../beam/sdk/AggregatorPipelineExtractor.java | 93 ++++++++
.../beam/sdk/AggregatorRetrievalException.java | 33 +++
.../org/apache/beam/sdk/AggregatorValues.java | 52 +++++
.../main/java/org/apache/beam/sdk/Pipeline.java | 10 +
.../org/apache/beam/sdk/PipelineResult.java | 2 -
.../runners/AggregatorPipelineExtractor.java | 93 --------
.../runners/AggregatorRetrievalException.java | 33 ---
.../beam/sdk/runners/AggregatorValues.java | 52 -----
.../sdk/AggregatorPipelineExtractorTest.java | 229 +++++++++++++++++++
.../AggregatorPipelineExtractorTest.java | 229 -------------------
.../apache/beam/sdk/transforms/DoFnTest.java | 1 +
.../apache/beam/sdk/transforms/OldDoFnTest.java | 3 +-
20 files changed, 434 insertions(+), 427 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
index a9c8ecb..f2b781e 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
@@ -20,15 +20,14 @@ package org.apache.beam.runners.direct;
import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupByKeyOnly;
import org.apache.beam.runners.direct.DirectRunner.DirectPipelineResult;
import org.apache.beam.runners.direct.ViewEvaluatorFactory.ViewOverrideFactory;
+import org.apache.beam.sdk.AggregatorRetrievalException;
+import org.apache.beam.sdk.AggregatorValues;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.Pipeline.PipelineExecutionException;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.io.Write;
import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.runners.AggregatorPipelineExtractor;
-import org.apache.beam.sdk.runners.AggregatorRetrievalException;
-import org.apache.beam.sdk.runners.AggregatorValues;
import org.apache.beam.sdk.runners.PipelineRunner;
import org.apache.beam.sdk.transforms.Aggregator;
import org.apache.beam.sdk.transforms.AppliedPTransform;
@@ -244,7 +243,7 @@ public class DirectRunner
executor.start(consumerTrackingVisitor.getRootTransforms());
Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps =
- new AggregatorPipelineExtractor(pipeline).getAggregatorSteps();
+ pipeline.getAggregatorSteps();
DirectPipelineResult result =
new DirectPipelineResult(executor, context, aggregatorSteps);
if (options.isBlockOnRun()) {
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java
index cae0b2a..923d54c 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java
@@ -18,8 +18,8 @@
package org.apache.beam.runners.flink;
import org.apache.beam.sdk.PipelineResult;
-import org.apache.beam.sdk.runners.AggregatorRetrievalException;
-import org.apache.beam.sdk.runners.AggregatorValues;
+import org.apache.beam.sdk.AggregatorRetrievalException;
+import org.apache.beam.sdk.AggregatorValues;
import org.apache.beam.sdk.transforms.Aggregator;
import org.joda.time.Duration;
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineJob.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineJob.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineJob.java
index a6baa4f..e043e23 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineJob.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineJob.java
@@ -23,9 +23,9 @@ import org.apache.beam.runners.dataflow.internal.DataflowAggregatorTransforms;
import org.apache.beam.runners.dataflow.internal.DataflowMetricUpdateExtractor;
import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions;
import org.apache.beam.runners.dataflow.util.MonitoringUtil;
+import org.apache.beam.sdk.AggregatorRetrievalException;
+import org.apache.beam.sdk.AggregatorValues;
import org.apache.beam.sdk.PipelineResult;
-import org.apache.beam.sdk.runners.AggregatorRetrievalException;
-import org.apache.beam.sdk.runners.AggregatorValues;
import org.apache.beam.sdk.transforms.Aggregator;
import org.apache.beam.sdk.util.AttemptAndTimeBoundedExponentialBackOff;
import org.apache.beam.sdk.util.AttemptBoundedExponentialBackOff;
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index fadd9c7..3b68e92 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -71,7 +71,6 @@ import org.apache.beam.sdk.io.Write;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsValidator;
import org.apache.beam.sdk.options.StreamingOptions;
-import org.apache.beam.sdk.runners.AggregatorPipelineExtractor;
import org.apache.beam.sdk.runners.PipelineRunner;
import org.apache.beam.sdk.runners.TransformTreeNode;
import org.apache.beam.sdk.transforms.Aggregator;
@@ -596,9 +595,8 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
// Obtain all of the extractors from the PTransforms used in the pipeline so the
// DataflowPipelineJob has access to them.
- AggregatorPipelineExtractor aggregatorExtractor = new AggregatorPipelineExtractor(pipeline);
Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps =
- aggregatorExtractor.getAggregatorSteps();
+ pipeline.getAggregatorSteps();
DataflowAggregatorTransforms aggregatorTransforms =
new DataflowAggregatorTransforms(aggregatorSteps, jobSpecification.getStepNames());
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java
index 343d538..e6277d9 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java
@@ -35,10 +35,10 @@ import static org.mockito.Mockito.when;
import org.apache.beam.runners.dataflow.internal.DataflowAggregatorTransforms;
import org.apache.beam.runners.dataflow.testing.TestDataflowPipelineOptions;
import org.apache.beam.runners.dataflow.util.MonitoringUtil;
+import org.apache.beam.sdk.AggregatorRetrievalException;
+import org.apache.beam.sdk.AggregatorValues;
import org.apache.beam.sdk.PipelineResult.State;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
-import org.apache.beam.sdk.runners.AggregatorRetrievalException;
-import org.apache.beam.sdk.runners.AggregatorValues;
import org.apache.beam.sdk.testing.FastNanoClockAndSleeper;
import org.apache.beam.sdk.transforms.Aggregator;
import org.apache.beam.sdk.transforms.AppliedPTransform;
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
index 169c2af..4ccac0e 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
@@ -22,10 +22,10 @@ import static com.google.common.base.Preconditions.checkArgument;
import org.apache.beam.runners.spark.EvaluationResult;
import org.apache.beam.runners.spark.coders.CoderHelpers;
+import org.apache.beam.sdk.AggregatorRetrievalException;
+import org.apache.beam.sdk.AggregatorValues;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.runners.AggregatorRetrievalException;
-import org.apache.beam.sdk.runners.AggregatorValues;
import org.apache.beam.sdk.transforms.Aggregator;
import org.apache.beam.sdk.transforms.AppliedPTransform;
import org.apache.beam.sdk.transforms.PTransform;
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java
index 46f5b33..c2edd02 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java
@@ -20,12 +20,12 @@ package org.apache.beam.runners.spark.translation;
import org.apache.beam.runners.spark.aggregators.AggAccumParam;
import org.apache.beam.runners.spark.aggregators.NamedAggregators;
+import org.apache.beam.sdk.AggregatorValues;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.runners.AggregatorValues;
import org.apache.beam.sdk.transforms.Aggregator;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Max;
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java
index 291f7b2..0d0c0b4 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java
@@ -20,11 +20,11 @@ package org.apache.beam.runners.spark.translation;
import org.apache.beam.runners.spark.EvaluationResult;
import org.apache.beam.runners.spark.SparkRunner;
+import org.apache.beam.sdk.AggregatorValues;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
-import org.apache.beam.sdk.runners.AggregatorValues;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Aggregator;
import org.apache.beam.sdk.transforms.ApproximateUnique;
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorPipelineExtractor.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorPipelineExtractor.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorPipelineExtractor.java
new file mode 100644
index 0000000..ac215c9
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorPipelineExtractor.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk;
+
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.runners.TransformTreeNode;
+import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.beam.sdk.transforms.AggregatorRetriever;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.PValue;
+
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.SetMultimap;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Map;
+
+/**
+ * Retrieves {@link Aggregator Aggregators} at each {@link ParDo} and returns a {@link Map} of
+ * {@link Aggregator} to the {@link PTransform PTransforms} in which it is present.
+ */
+class AggregatorPipelineExtractor {
+ private final Pipeline pipeline;
+
+ /**
+ * Creates an {@code AggregatorPipelineExtractor} for the given {@link Pipeline}.
+ */
+ public AggregatorPipelineExtractor(Pipeline pipeline) {
+ this.pipeline = pipeline;
+ }
+
+ /**
+ * Returns a {@link Map} between each {@link Aggregator} in the {@link Pipeline} to the {@link
+ * PTransform PTransforms} in which it is used.
+ */
+ public Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> getAggregatorSteps() {
+ HashMultimap<Aggregator<?, ?>, PTransform<?, ?>> aggregatorSteps = HashMultimap.create();
+ pipeline.traverseTopologically(new AggregatorVisitor(aggregatorSteps));
+ return aggregatorSteps.asMap();
+ }
+
+ private static class AggregatorVisitor extends PipelineVisitor.Defaults {
+ private final SetMultimap<Aggregator<?, ?>, PTransform<?, ?>> aggregatorSteps;
+
+ public AggregatorVisitor(SetMultimap<Aggregator<?, ?>, PTransform<?, ?>> aggregatorSteps) {
+ this.aggregatorSteps = aggregatorSteps;
+ }
+
+ @Override
+ public void visitPrimitiveTransform(TransformTreeNode node) {
+ PTransform<?, ?> transform = node.getTransform();
+ addStepToAggregators(transform, getAggregators(transform));
+ }
+
+ private Collection<Aggregator<?, ?>> getAggregators(PTransform<?, ?> transform) {
+ if (transform != null) {
+ if (transform instanceof ParDo.Bound) {
+ return AggregatorRetriever.getAggregators(((ParDo.Bound<?, ?>) transform).getFn());
+ } else if (transform instanceof ParDo.BoundMulti) {
+ return AggregatorRetriever.getAggregators(((ParDo.BoundMulti<?, ?>) transform).getFn());
+ }
+ }
+ return Collections.emptyList();
+ }
+
+ private void addStepToAggregators(
+ PTransform<?, ?> transform, Collection<Aggregator<?, ?>> aggregators) {
+ for (Aggregator<?, ?> aggregator : aggregators) {
+ aggregatorSteps.put(aggregator, transform);
+ }
+ }
+
+ @Override
+ public void visitValue(PValue value, TransformTreeNode producer) {}
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorRetrievalException.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorRetrievalException.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorRetrievalException.java
new file mode 100644
index 0000000..3040815
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorRetrievalException.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk;
+
+import org.apache.beam.sdk.transforms.Aggregator;
+
+/**
+ * Signals that an exception has occurred while retrieving {@link Aggregator}s.
+ */
+public class AggregatorRetrievalException extends Exception {
+ /**
+ * Constructs a new {@code AggregatorRetrievalException} with the specified detail message and
+ * cause.
+ */
+ public AggregatorRetrievalException(String message, Throwable cause) {
+ super(message, cause);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorValues.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorValues.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorValues.java
new file mode 100644
index 0000000..efaad85
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorValues.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk;
+
+import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.OldDoFn;
+
+import java.util.Collection;
+import java.util.Map;
+
+/**
+ * A collection of values associated with an {@link Aggregator}. Aggregators declared in a
+ * {@link OldDoFn} are emitted on a per-{@code OldDoFn}-application basis.
+ *
+ * @param <T> the output type of the aggregator
+ */
+public abstract class AggregatorValues<T> {
+ /**
+ * Get the values of the {@link Aggregator} at all steps it was used.
+ */
+ public Collection<T> getValues() {
+ return getValuesAtSteps().values();
+ }
+
+ /**
+ * Get the values of the {@link Aggregator} by the user name at each step it was used.
+ */
+ public abstract Map<String, T> getValuesAtSteps();
+
+ /**
+ * Get the total value of this {@link Aggregator} by applying the specified {@link CombineFn}.
+ */
+ public T getTotalValue(CombineFn<T, ?, T> combineFn) {
+ return combineFn.apply(getValues());
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
index e4f3e4a..1bbc56f 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
@@ -26,6 +26,7 @@ import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.runners.PipelineRunner;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.runners.TransformTreeNode;
+import org.apache.beam.sdk.transforms.Aggregator;
import org.apache.beam.sdk.transforms.AppliedPTransform;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.PTransform;
@@ -47,6 +48,7 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
+import java.util.Map;
import java.util.Set;
/**
@@ -518,6 +520,14 @@ public class Pipeline {
}
/**
+ * Returns a {@link Map} from each {@link Aggregator} in the {@link Pipeline} to the {@link
+ * PTransform PTransforms} in which it is used.
+ */
+ public Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> getAggregatorSteps() {
+ return new AggregatorPipelineExtractor(this).getAggregatorSteps();
+ }
+
+ /**
* Builds a name from a "/"-delimited prefix and a name.
*/
private String buildName(String namePrefix, String name) {
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/main/java/org/apache/beam/sdk/PipelineResult.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/PipelineResult.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/PipelineResult.java
index 993962c..edfc924 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/PipelineResult.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/PipelineResult.java
@@ -17,8 +17,6 @@
*/
package org.apache.beam.sdk;
-import org.apache.beam.sdk.runners.AggregatorRetrievalException;
-import org.apache.beam.sdk.runners.AggregatorValues;
import org.apache.beam.sdk.transforms.Aggregator;
import org.joda.time.Duration;
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractor.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractor.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractor.java
deleted file mode 100644
index 146ddfa..0000000
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractor.java
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.runners;
-
-import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.Pipeline.PipelineVisitor;
-import org.apache.beam.sdk.transforms.Aggregator;
-import org.apache.beam.sdk.transforms.AggregatorRetriever;
-import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.values.PValue;
-
-import com.google.common.collect.HashMultimap;
-import com.google.common.collect.SetMultimap;
-
-import java.util.Collection;
-import java.util.Collections;
-import java.util.Map;
-
-/**
- * Retrieves {@link Aggregator Aggregators} at each {@link ParDo} and returns a {@link Map} of
- * {@link Aggregator} to the {@link PTransform PTransforms} in which it is present.
- */
-public class AggregatorPipelineExtractor {
- private final Pipeline pipeline;
-
- /**
- * Creates an {@code AggregatorPipelineExtractor} for the given {@link Pipeline}.
- */
- public AggregatorPipelineExtractor(Pipeline pipeline) {
- this.pipeline = pipeline;
- }
-
- /**
- * Returns a {@link Map} between each {@link Aggregator} in the {@link Pipeline} to the {@link
- * PTransform PTransforms} in which it is used.
- */
- public Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> getAggregatorSteps() {
- HashMultimap<Aggregator<?, ?>, PTransform<?, ?>> aggregatorSteps = HashMultimap.create();
- pipeline.traverseTopologically(new AggregatorVisitor(aggregatorSteps));
- return aggregatorSteps.asMap();
- }
-
- private static class AggregatorVisitor extends PipelineVisitor.Defaults {
- private final SetMultimap<Aggregator<?, ?>, PTransform<?, ?>> aggregatorSteps;
-
- public AggregatorVisitor(SetMultimap<Aggregator<?, ?>, PTransform<?, ?>> aggregatorSteps) {
- this.aggregatorSteps = aggregatorSteps;
- }
-
- @Override
- public void visitPrimitiveTransform(TransformTreeNode node) {
- PTransform<?, ?> transform = node.getTransform();
- addStepToAggregators(transform, getAggregators(transform));
- }
-
- private Collection<Aggregator<?, ?>> getAggregators(PTransform<?, ?> transform) {
- if (transform != null) {
- if (transform instanceof ParDo.Bound) {
- return AggregatorRetriever.getAggregators(((ParDo.Bound<?, ?>) transform).getFn());
- } else if (transform instanceof ParDo.BoundMulti) {
- return AggregatorRetriever.getAggregators(((ParDo.BoundMulti<?, ?>) transform).getFn());
- }
- }
- return Collections.emptyList();
- }
-
- private void addStepToAggregators(
- PTransform<?, ?> transform, Collection<Aggregator<?, ?>> aggregators) {
- for (Aggregator<?, ?> aggregator : aggregators) {
- aggregatorSteps.put(aggregator, transform);
- }
- }
-
- @Override
- public void visitValue(PValue value, TransformTreeNode producer) {}
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorRetrievalException.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorRetrievalException.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorRetrievalException.java
deleted file mode 100644
index a0973c3..0000000
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorRetrievalException.java
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.runners;
-
-import org.apache.beam.sdk.transforms.Aggregator;
-
-/**
- * Signals that an exception has occurred while retrieving {@link Aggregator}s.
- */
-public class AggregatorRetrievalException extends Exception {
- /**
- * Constructs a new {@code AggregatorRetrievalException} with the specified detail message and
- * cause.
- */
- public AggregatorRetrievalException(String message, Throwable cause) {
- super(message, cause);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorValues.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorValues.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorValues.java
deleted file mode 100644
index 6f6836e..0000000
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorValues.java
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.runners;
-
-import org.apache.beam.sdk.transforms.Aggregator;
-import org.apache.beam.sdk.transforms.Combine.CombineFn;
-import org.apache.beam.sdk.transforms.OldDoFn;
-
-import java.util.Collection;
-import java.util.Map;
-
-/**
- * A collection of values associated with an {@link Aggregator}. Aggregators declared in a
- * {@link OldDoFn} are emitted on a per-{@code OldDoFn}-application basis.
- *
- * @param <T> the output type of the aggregator
- */
-public abstract class AggregatorValues<T> {
- /**
- * Get the values of the {@link Aggregator} at all steps it was used.
- */
- public Collection<T> getValues() {
- return getValuesAtSteps().values();
- }
-
- /**
- * Get the values of the {@link Aggregator} by the user name at each step it was used.
- */
- public abstract Map<String, T> getValuesAtSteps();
-
- /**
- * Get the total value of this {@link Aggregator} by applying the specified {@link CombineFn}.
- */
- public T getTotalValue(CombineFn<T, ?, T> combineFn) {
- return combineFn.apply(getValues());
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/test/java/org/apache/beam/sdk/AggregatorPipelineExtractorTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/AggregatorPipelineExtractorTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/AggregatorPipelineExtractorTest.java
new file mode 100644
index 0000000..930fbe7
--- /dev/null
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/AggregatorPipelineExtractorTest.java
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.runners.TransformTreeNode;
+import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.Max;
+import org.apache.beam.sdk.transforms.Min;
+import org.apache.beam.sdk.transforms.OldDoFn;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Sum;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+
+/**
+ * Tests for {@link AggregatorPipelineExtractor}.
+ */
+@RunWith(JUnit4.class)
+public class AggregatorPipelineExtractorTest {
+ @Mock
+ private Pipeline p;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this); // creates the @Mock-annotated Pipeline field p
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test // a single ParDo.Bound step whose fn declares two aggregators
+ public void testGetAggregatorStepsWithParDoBoundExtractsSteps() {
+ @SuppressWarnings("rawtypes")
+ ParDo.Bound bound = mock(ParDo.Bound.class, "Bound");
+ AggregatorProvidingDoFn<ThreadGroup, StrictMath> fn = new AggregatorProvidingDoFn<>();
+ when(bound.getFn()).thenReturn(fn); // the mocked transform hands out the real fn
+
+ Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Sum.SumLongFn());
+ Aggregator<Integer, Integer> aggregatorTwo = fn.addAggregator(new Min.MinIntegerFn());
+
+ TransformTreeNode transformNode = mock(TransformTreeNode.class);
+ when(transformNode.getTransform()).thenReturn(bound);
+
+ doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode))) // replay the node to the visitor
+ .when(p)
+ .traverseTopologically(Mockito.any(PipelineVisitor.class));
+
+ AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p);
+
+ Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps =
+ extractor.getAggregatorSteps();
+
+ assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorOne));
+ assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorTwo));
+ assertEquals(2, aggregatorSteps.size()); // JUnit order is (expected, actual)
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test // a single ParDo.BoundMulti step whose fn declares two aggregators
+ public void testGetAggregatorStepsWithParDoBoundMultiExtractsSteps() {
+ @SuppressWarnings("rawtypes")
+ ParDo.BoundMulti bound = mock(ParDo.BoundMulti.class, "BoundMulti");
+ AggregatorProvidingDoFn<Object, Void> fn = new AggregatorProvidingDoFn<>();
+ when(bound.getFn()).thenReturn(fn); // the mocked transform hands out the real fn
+
+ Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Max.MaxLongFn());
+ Aggregator<Double, Double> aggregatorTwo = fn.addAggregator(new Min.MinDoubleFn());
+
+ TransformTreeNode transformNode = mock(TransformTreeNode.class);
+ when(transformNode.getTransform()).thenReturn(bound);
+
+ doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode))) // replay the node to the visitor
+ .when(p)
+ .traverseTopologically(Mockito.any(PipelineVisitor.class));
+
+ AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p);
+
+ Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps =
+ extractor.getAggregatorSteps();
+
+ assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorOne));
+ assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorTwo));
+ assertEquals(2, aggregatorSteps.size()); // one map entry per aggregator
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test // two different steps that share one fn instance
+ public void testGetAggregatorStepsWithOneAggregatorInMultipleStepsAddsSteps() {
+ @SuppressWarnings("rawtypes")
+ ParDo.Bound bound = mock(ParDo.Bound.class, "Bound");
+ @SuppressWarnings("rawtypes")
+ ParDo.BoundMulti otherBound = mock(ParDo.BoundMulti.class, "otherBound");
+ AggregatorProvidingDoFn<String, Math> fn = new AggregatorProvidingDoFn<>();
+ when(bound.getFn()).thenReturn(fn);
+ when(otherBound.getFn()).thenReturn(fn); // same fn object as bound
+
+ Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Sum.SumLongFn());
+ Aggregator<Double, Double> aggregatorTwo = fn.addAggregator(new Min.MinDoubleFn());
+
+ TransformTreeNode transformNode = mock(TransformTreeNode.class);
+ when(transformNode.getTransform()).thenReturn(bound);
+ TransformTreeNode otherTransformNode = mock(TransformTreeNode.class);
+ when(otherTransformNode.getTransform()).thenReturn(otherBound);
+
+ doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode, otherTransformNode)))
+ .when(p)
+ .traverseTopologically(Mockito.any(PipelineVisitor.class));
+
+ AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p);
+
+ Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps =
+ extractor.getAggregatorSteps();
+
+ assertEquals( // each aggregator is expected to map to both steps
+ ImmutableSet.<PTransform<?, ?>>of(bound, otherBound), aggregatorSteps.get(aggregatorOne));
+ assertEquals(
+ ImmutableSet.<PTransform<?, ?>>of(bound, otherBound), aggregatorSteps.get(aggregatorTwo));
+ assertEquals(2, aggregatorSteps.size()); // one map entry per aggregator
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test // two steps, each with its own fn declaring one aggregator
+ public void testGetAggregatorStepsWithDifferentStepsAddsSteps() {
+ @SuppressWarnings("rawtypes")
+ ParDo.Bound bound = mock(ParDo.Bound.class, "Bound");
+
+ AggregatorProvidingDoFn<ThreadGroup, Void> fn = new AggregatorProvidingDoFn<>();
+ Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Sum.SumLongFn());
+
+ when(bound.getFn()).thenReturn(fn);
+
+ @SuppressWarnings("rawtypes")
+ ParDo.BoundMulti otherBound = mock(ParDo.BoundMulti.class, "otherBound");
+
+ AggregatorProvidingDoFn<Long, Long> otherFn = new AggregatorProvidingDoFn<>();
+ Aggregator<Double, Double> aggregatorTwo = otherFn.addAggregator(new Sum.SumDoubleFn());
+
+ when(otherBound.getFn()).thenReturn(otherFn); // a separate fn from the one behind bound
+
+ TransformTreeNode transformNode = mock(TransformTreeNode.class);
+ when(transformNode.getTransform()).thenReturn(bound);
+ TransformTreeNode otherTransformNode = mock(TransformTreeNode.class);
+ when(otherTransformNode.getTransform()).thenReturn(otherBound);
+
+ doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode, otherTransformNode)))
+ .when(p)
+ .traverseTopologically(Mockito.any(PipelineVisitor.class));
+
+ AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p);
+
+ Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps =
+ extractor.getAggregatorSteps();
+
+ assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorOne));
+ assertEquals(ImmutableSet.<PTransform<?, ?>>of(otherBound), aggregatorSteps.get(aggregatorTwo));
+ assertEquals(2, aggregatorSteps.size()); // one map entry per aggregator
+ }
+
+ private static class VisitNodesAnswer implements Answer<Object> { // stubs traverseTopologically
+ private final List<TransformTreeNode> nodes; // nodes handed to the visitor, in list order
+
+ public VisitNodesAnswer(List<TransformTreeNode> nodes) {
+ this.nodes = nodes;
+ }
+
+ @Override
+ public Object answer(InvocationOnMock invocation) throws Throwable {
+ PipelineVisitor visitor = (PipelineVisitor) invocation.getArguments()[0]; // first call argument
+ for (TransformTreeNode node : nodes) {
+ visitor.visitPrimitiveTransform(node); // every node is presented as a primitive transform
+ }
+ return null;
+ }
+ }
+
+ private static class AggregatorProvidingDoFn<InT, OuT> extends OldDoFn<InT, OuT> {
+ public <InputT, OutT> Aggregator<InputT, OutT> addAggregator(
+ CombineFn<InputT, ?, OutT> combiner) {
+ return createAggregator(randomName(), combiner); // created under a freshly generated name
+ }
+
+ private String randomName() {
+ return UUID.randomUUID().toString(); // a new random UUID string on every call
+ }
+
+ @Override
+ public void processElement(OldDoFn<InT, OuT>.ProcessContext c) throws Exception {
+ fail(); // fails the running test if element processing is ever invoked
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractorTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractorTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractorTest.java
deleted file mode 100644
index 13476e2..0000000
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractorTest.java
+++ /dev/null
@@ -1,229 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.runners;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.fail;
-import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-
-import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.Pipeline.PipelineVisitor;
-import org.apache.beam.sdk.transforms.Aggregator;
-import org.apache.beam.sdk.transforms.Combine.CombineFn;
-import org.apache.beam.sdk.transforms.Max;
-import org.apache.beam.sdk.transforms.Min;
-import org.apache.beam.sdk.transforms.OldDoFn;
-import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.transforms.Sum;
-
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableSet;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-import org.mockito.Mock;
-import org.mockito.Mockito;
-import org.mockito.MockitoAnnotations;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
-
-import java.util.Collection;
-import java.util.List;
-import java.util.Map;
-import java.util.UUID;
-
-/**
- * Tests for {@link AggregatorPipelineExtractor}.
- */
-@RunWith(JUnit4.class)
-public class AggregatorPipelineExtractorTest {
- @Mock
- private Pipeline p;
-
- @Before
- public void setup() {
- MockitoAnnotations.initMocks(this);
- }
-
- @SuppressWarnings("unchecked")
- @Test
- public void testGetAggregatorStepsWithParDoBoundExtractsSteps() {
- @SuppressWarnings("rawtypes")
- ParDo.Bound bound = mock(ParDo.Bound.class, "Bound");
- AggregatorProvidingDoFn<ThreadGroup, StrictMath> fn = new AggregatorProvidingDoFn<>();
- when(bound.getFn()).thenReturn(fn);
-
- Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Sum.SumLongFn());
- Aggregator<Integer, Integer> aggregatorTwo = fn.addAggregator(new Min.MinIntegerFn());
-
- TransformTreeNode transformNode = mock(TransformTreeNode.class);
- when(transformNode.getTransform()).thenReturn(bound);
-
- doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode)))
- .when(p)
- .traverseTopologically(Mockito.any(PipelineVisitor.class));
-
- AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p);
-
- Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps =
- extractor.getAggregatorSteps();
-
- assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorOne));
- assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorTwo));
- assertEquals(aggregatorSteps.size(), 2);
- }
-
- @SuppressWarnings("unchecked")
- @Test
- public void testGetAggregatorStepsWithParDoBoundMultiExtractsSteps() {
- @SuppressWarnings("rawtypes")
- ParDo.BoundMulti bound = mock(ParDo.BoundMulti.class, "BoundMulti");
- AggregatorProvidingDoFn<Object, Void> fn = new AggregatorProvidingDoFn<>();
- when(bound.getFn()).thenReturn(fn);
-
- Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Max.MaxLongFn());
- Aggregator<Double, Double> aggregatorTwo = fn.addAggregator(new Min.MinDoubleFn());
-
- TransformTreeNode transformNode = mock(TransformTreeNode.class);
- when(transformNode.getTransform()).thenReturn(bound);
-
- doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode)))
- .when(p)
- .traverseTopologically(Mockito.any(PipelineVisitor.class));
-
- AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p);
-
- Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps =
- extractor.getAggregatorSteps();
-
- assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorOne));
- assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorTwo));
- assertEquals(2, aggregatorSteps.size());
- }
-
- @SuppressWarnings("unchecked")
- @Test
- public void testGetAggregatorStepsWithOneAggregatorInMultipleStepsAddsSteps() {
- @SuppressWarnings("rawtypes")
- ParDo.Bound bound = mock(ParDo.Bound.class, "Bound");
- @SuppressWarnings("rawtypes")
- ParDo.BoundMulti otherBound = mock(ParDo.BoundMulti.class, "otherBound");
- AggregatorProvidingDoFn<String, Math> fn = new AggregatorProvidingDoFn<>();
- when(bound.getFn()).thenReturn(fn);
- when(otherBound.getFn()).thenReturn(fn);
-
- Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Sum.SumLongFn());
- Aggregator<Double, Double> aggregatorTwo = fn.addAggregator(new Min.MinDoubleFn());
-
- TransformTreeNode transformNode = mock(TransformTreeNode.class);
- when(transformNode.getTransform()).thenReturn(bound);
- TransformTreeNode otherTransformNode = mock(TransformTreeNode.class);
- when(otherTransformNode.getTransform()).thenReturn(otherBound);
-
- doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode, otherTransformNode)))
- .when(p)
- .traverseTopologically(Mockito.any(PipelineVisitor.class));
-
- AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p);
-
- Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps =
- extractor.getAggregatorSteps();
-
- assertEquals(
- ImmutableSet.<PTransform<?, ?>>of(bound, otherBound), aggregatorSteps.get(aggregatorOne));
- assertEquals(
- ImmutableSet.<PTransform<?, ?>>of(bound, otherBound), aggregatorSteps.get(aggregatorTwo));
- assertEquals(2, aggregatorSteps.size());
- }
-
- @SuppressWarnings("unchecked")
- @Test
- public void testGetAggregatorStepsWithDifferentStepsAddsSteps() {
- @SuppressWarnings("rawtypes")
- ParDo.Bound bound = mock(ParDo.Bound.class, "Bound");
-
- AggregatorProvidingDoFn<ThreadGroup, Void> fn = new AggregatorProvidingDoFn<>();
- Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Sum.SumLongFn());
-
- when(bound.getFn()).thenReturn(fn);
-
- @SuppressWarnings("rawtypes")
- ParDo.BoundMulti otherBound = mock(ParDo.BoundMulti.class, "otherBound");
-
- AggregatorProvidingDoFn<Long, Long> otherFn = new AggregatorProvidingDoFn<>();
- Aggregator<Double, Double> aggregatorTwo = otherFn.addAggregator(new Sum.SumDoubleFn());
-
- when(otherBound.getFn()).thenReturn(otherFn);
-
- TransformTreeNode transformNode = mock(TransformTreeNode.class);
- when(transformNode.getTransform()).thenReturn(bound);
- TransformTreeNode otherTransformNode = mock(TransformTreeNode.class);
- when(otherTransformNode.getTransform()).thenReturn(otherBound);
-
- doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode, otherTransformNode)))
- .when(p)
- .traverseTopologically(Mockito.any(PipelineVisitor.class));
-
- AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p);
-
- Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps =
- extractor.getAggregatorSteps();
-
- assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorOne));
- assertEquals(ImmutableSet.<PTransform<?, ?>>of(otherBound), aggregatorSteps.get(aggregatorTwo));
- assertEquals(2, aggregatorSteps.size());
- }
-
- private static class VisitNodesAnswer implements Answer<Object> {
- private final List<TransformTreeNode> nodes;
-
- public VisitNodesAnswer(List<TransformTreeNode> nodes) {
- this.nodes = nodes;
- }
-
- @Override
- public Object answer(InvocationOnMock invocation) throws Throwable {
- PipelineVisitor visitor = (PipelineVisitor) invocation.getArguments()[0];
- for (TransformTreeNode node : nodes) {
- visitor.visitPrimitiveTransform(node);
- }
- return null;
- }
- }
-
- private static class AggregatorProvidingDoFn<InT, OuT> extends OldDoFn<InT, OuT> {
- public <InputT, OutT> Aggregator<InputT, OutT> addAggregator(
- CombineFn<InputT, ?, OutT> combiner) {
- return createAggregator(randomName(), combiner);
- }
-
- private String randomName() {
- return UUID.randomUUID().toString();
- }
-
- @Override
- public void processElement(OldDoFn<InT, OuT>.ProcessContext c) throws Exception {
- fail();
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTest.java
index 710e4ce..3fb3193 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTest.java
@@ -128,6 +128,7 @@ public class DoFnTest implements Serializable {
DoFn<Void, Void> doFn = new NoOpDoFn();
Aggregator<Double, Double> aggregatorOne =
+
doFn.createAggregator(nameOne, combiner);
Aggregator<Double, Double> aggregatorTwo =
doFn.createAggregator(nameTwo, combiner);
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/OldDoFnTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/OldDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/OldDoFnTest.java
index 9d144b3..5946d9a 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/OldDoFnTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/OldDoFnTest.java
@@ -24,10 +24,10 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertThat;
+import org.apache.beam.sdk.AggregatorValues;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.Pipeline.PipelineExecutionException;
import org.apache.beam.sdk.PipelineResult;
-import org.apache.beam.sdk.runners.AggregatorValues;
import org.apache.beam.sdk.testing.NeedsRunner;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
@@ -36,6 +36,7 @@ import org.apache.beam.sdk.transforms.Sum.SumIntegerFn;
import org.apache.beam.sdk.transforms.display.DisplayData;
import com.google.common.collect.ImmutableMap;
+
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;