You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by tg...@apache.org on 2017/03/17 15:54:30 UTC
[2/4] beam git commit: Implement Single-Output ParDo as a composite
http://git-wip-us.apache.org/repos/asf/beam/blob/c6cad209/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java
new file mode 100644
index 0000000..cb1e34e
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.dataflow;
+
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.equalTo;
+import static org.junit.Assert.assertThat;
+
+import com.google.common.collect.Iterables;
+import java.io.Serializable;
+import java.util.List;
+import org.apache.beam.runners.dataflow.PrimitiveParDoSingleFactory.ParDoSingle;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.display.DisplayDataEvaluator;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link PrimitiveParDoSingleFactory}.
+ */
+@RunWith(JUnit4.class)
+public class PrimitiveParDoSingleFactoryTest implements Serializable {
+ // Create a pipeline for testing Side Input propagation. This won't actually run any Pipelines,
+ // so disable enforcement.
+ @Rule
+ public transient TestPipeline pipeline =
+ TestPipeline.create().enableAbandonedNodeEnforcement(false);
+
+ // The factory under test: replaces a single-output ParDo.Bound with a ParDoSingle primitive.
+ private PrimitiveParDoSingleFactory<Integer, Long> factory = new PrimitiveParDoSingleFactory<>();
+
+ /**
+ * A test that demonstrates that the replacement transform has the Display Data of the
+ * {@link ParDo.Bound} it replaces.
+ */
+ @Test
+ public void getReplacementTransformPopulateDisplayData() {
+ ParDo.Bound<Integer, Long> originalTransform = ParDo.of(new ToLongFn());
+ DisplayData originalDisplayData = DisplayData.from(originalTransform);
+
+ PTransform<PCollection<? extends Integer>, PCollection<Long>> replacement =
+ factory.getReplacementTransform(originalTransform);
+ DisplayData replacementDisplayData = DisplayData.from(replacement);
+
+ // The replacement must advertise the same Display Data as the original transform ...
+ assertThat(replacementDisplayData, equalTo(originalDisplayData));
+
+ // ... and, being a primitive, must surface that same Display Data when the pipeline's
+ // primitive transforms are evaluated.
+ DisplayData primitiveDisplayData =
+ Iterables.getOnlyElement(
+ DisplayDataEvaluator.create()
+ .displayDataForPrimitiveTransforms(replacement, VarIntCoder.of()));
+ assertThat(primitiveDisplayData, equalTo(replacementDisplayData));
+ }
+
+ /**
+ * Tests that the side inputs of the original {@link ParDo.Bound} are all present, in any
+ * order, on the replacement {@link ParDoSingle}.
+ */
+ @Test
+ public void getReplacementTransformGetSideInputs() {
+ PCollectionView<Long> sideLong =
+ pipeline
+ .apply("LongSideInputVals", Create.of(-1L, -2L, -4L))
+ .apply("SideLongView", Sum.longsGlobally().asSingletonView());
+ PCollectionView<List<String>> sideStrings =
+ pipeline
+ .apply("StringSideInputVals", Create.of("foo", "bar", "baz"))
+ .apply("SideStringsView", View.<String>asList());
+ ParDo.Bound<Integer, Long> originalTransform =
+ ParDo.of(new ToLongFn()).withSideInputs(sideLong, sideStrings);
+
+ PTransform<PCollection<? extends Integer>, PCollection<Long>> replacementTransform =
+ factory.getReplacementTransform(originalTransform);
+ ParDoSingle<Integer, Long> parDoSingle = (ParDoSingle<Integer, Long>) replacementTransform;
+ assertThat(parDoSingle.getSideInputs(), containsInAnyOrder(sideStrings, sideLong));
+ }
+
+ /**
+ * Tests that the replacement {@link ParDoSingle} reports the same {@link DoFn} instance as
+ * the original {@link ParDo.Bound}.
+ */
+ @Test
+ public void getReplacementTransformGetFn() {
+ DoFn<Integer, Long> originalFn = new ToLongFn();
+ ParDo.Bound<Integer, Long> originalTransform = ParDo.of(originalFn);
+ PTransform<PCollection<? extends Integer>, PCollection<Long>> replacementTransform =
+ factory.getReplacementTransform(originalTransform);
+ ParDoSingle<Integer, Long> parDoSingle = (ParDoSingle<Integer, Long>) replacementTransform;
+
+ assertThat(parDoSingle.getFn(), equalTo(originalTransform.getFn()));
+ assertThat(parDoSingle.getFn(), equalTo(originalFn));
+ }
+
+ /**
+ * A {@link DoFn} that outputs each input {@link Integer} as its {@code long} value.
+ * Instances compare equal by class so that Display Data assertions match across copies.
+ */
+ private static class ToLongFn extends DoFn<Integer, Long> {
+ @ProcessElement
+ public void toLong(ProcessContext ctxt) {
+ ctxt.output(ctxt.element().longValue());
+ }
+
+ // NOTE(review): equals and hashCode are missing @Override; consider adding it so the
+ // compiler verifies these actually override Object's methods.
+ public boolean equals(Object other) {
+ return other != null && other.getClass().equals(getClass());
+ }
+
+ public int hashCode() {
+ return getClass().hashCode();
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/c6cad209/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 8d1b82e..b4362b0 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -27,6 +27,7 @@ import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.replaceSh
import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectSplittable;
import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers;
+import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.io.IOException;
@@ -348,38 +349,8 @@ public final class TransformTranslator {
};
}
- private static <InputT, OutputT> TransformEvaluator<ParDo.Bound<InputT, OutputT>> parDo() {
- return new TransformEvaluator<ParDo.Bound<InputT, OutputT>>() {
- @Override
- public void evaluate(ParDo.Bound<InputT, OutputT> transform, EvaluationContext context) {
- String stepName = context.getCurrentTransform().getFullName();
- DoFn<InputT, OutputT> doFn = transform.getFn();
- rejectSplittable(doFn);
- rejectStateAndTimers(doFn);
- @SuppressWarnings("unchecked")
- JavaRDD<WindowedValue<InputT>> inRDD =
- ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD();
- WindowingStrategy<?, ?> windowingStrategy =
- context.getInput(transform).getWindowingStrategy();
- Accumulator<NamedAggregators> aggAccum = AggregatorsAccumulator.getInstance();
- Accumulator<SparkMetricsContainer> metricsAccum =
- MetricsAccumulator.getInstance();
- Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
- TranslationUtils.getSideInputs(transform.getSideInputs(), context);
- context.putDataset(transform,
- new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(aggAccum, metricsAccum,
- stepName, doFn, context.getRuntimeContext(), sideInputs, windowingStrategy))));
- }
-
- @Override
- public String toNativeString() {
- return "mapPartitions(new <fn>())";
- }
- };
- }
-
private static <InputT, OutputT> TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>>
- multiDo() {
+ parDo() {
return new TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>>() {
@Override
public void evaluate(ParDo.BoundMulti<InputT, OutputT> transform, EvaluationContext context) {
@@ -393,24 +364,52 @@ public final class TransformTranslator {
WindowingStrategy<?, ?> windowingStrategy =
context.getInput(transform).getWindowingStrategy();
Accumulator<NamedAggregators> aggAccum = AggregatorsAccumulator.getInstance();
- Accumulator<SparkMetricsContainer> metricsAccum =
- MetricsAccumulator.getInstance();
- JavaPairRDD<TupleTag<?>, WindowedValue<?>> all = inRDD
- .mapPartitionsToPair(
- new MultiDoFnFunction<>(aggAccum, metricsAccum, stepName, doFn,
- context.getRuntimeContext(), transform.getMainOutputTag(),
- TranslationUtils.getSideInputs(transform.getSideInputs(), context),
- windowingStrategy)).cache();
- List<TaggedPValue> pct = context.getOutputs(transform);
- for (TaggedPValue e : pct) {
- @SuppressWarnings("unchecked")
- JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
- all.filter(new TranslationUtils.TupleTagFilter(e.getTag()));
- @SuppressWarnings("unchecked")
- // Object is the best we can do since different outputs can have different tags
- JavaRDD<WindowedValue<Object>> values =
- (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values();
- context.putDataset(e.getValue(), new BoundedDataset<>(values));
+ Accumulator<SparkMetricsContainer> metricsAccum = MetricsAccumulator.getInstance();
+ Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
+ TranslationUtils.getSideInputs(transform.getSideInputs(), context);
+ if (transform.getSideOutputTags().size() == 0) {
+ // Don't tag with the output and filter for a single-output ParDo, as that only adds
+ // identity transforms.
+ // Also see BEAM-1737 for failures when the two versions are condensed.
+ PCollection<OutputT> output =
+ (PCollection<OutputT>)
+ Iterables.getOnlyElement(context.getOutputs(transform)).getValue();
+ context.putDataset(
+ output,
+ new BoundedDataset<>(
+ inRDD.mapPartitions(
+ new DoFnFunction<>(
+ aggAccum,
+ metricsAccum,
+ stepName,
+ doFn,
+ context.getRuntimeContext(),
+ sideInputs,
+ windowingStrategy))));
+ } else {
+ JavaPairRDD<TupleTag<?>, WindowedValue<?>> all =
+ inRDD
+ .mapPartitionsToPair(
+ new MultiDoFnFunction<>(
+ aggAccum,
+ metricsAccum,
+ stepName,
+ doFn,
+ context.getRuntimeContext(),
+ transform.getMainOutputTag(),
+ TranslationUtils.getSideInputs(transform.getSideInputs(), context),
+ windowingStrategy))
+ .cache();
+ for (TaggedPValue output : context.getOutputs(transform)) {
+ @SuppressWarnings("unchecked")
+ JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
+ all.filter(new TranslationUtils.TupleTagFilter(output.getTag()));
+ @SuppressWarnings("unchecked")
+ // Object is the best we can do since different outputs can have different tags
+ JavaRDD<WindowedValue<Object>> values =
+ (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values();
+ context.putDataset(output.getValue(), new BoundedDataset<>(values));
+ }
}
}
@@ -842,8 +841,7 @@ public final class TransformTranslator {
EVALUATORS.put(Read.Bounded.class, readBounded());
EVALUATORS.put(HadoopIO.Read.Bound.class, readHadoop());
EVALUATORS.put(HadoopIO.Write.Bound.class, writeHadoop());
- EVALUATORS.put(ParDo.Bound.class, parDo());
- EVALUATORS.put(ParDo.BoundMulti.class, multiDo());
+ EVALUATORS.put(ParDo.BoundMulti.class, parDo());
EVALUATORS.put(GroupByKey.class, groupByKey());
EVALUATORS.put(Combine.GroupedValues.class, combineGrouped());
EVALUATORS.put(Combine.Globally.class, combineGlobally());
http://git-wip-us.apache.org/repos/asf/beam/blob/c6cad209/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 2744169..25fecf6 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -366,62 +366,11 @@ public final class StreamingTransformTranslator {
};
}
- private static <InputT, OutputT> TransformEvaluator<ParDo.Bound<InputT, OutputT>> parDo() {
- return new TransformEvaluator<ParDo.Bound<InputT, OutputT>>() {
- @Override
- public void evaluate(final ParDo.Bound<InputT, OutputT> transform,
- final EvaluationContext context) {
- final DoFn<InputT, OutputT> doFn = transform.getFn();
- rejectSplittable(doFn);
- rejectStateAndTimers(doFn);
- final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
- final WindowingStrategy<?, ?> windowingStrategy =
- context.getInput(transform).getWindowingStrategy();
- final SparkPCollectionView pviews = context.getPViews();
-
- @SuppressWarnings("unchecked")
- UnboundedDataset<InputT> unboundedDataset =
- ((UnboundedDataset<InputT>) context.borrowDataset(transform));
- JavaDStream<WindowedValue<InputT>> dStream = unboundedDataset.getDStream();
-
- final String stepName = context.getCurrentTransform().getFullName();
-
- JavaDStream<WindowedValue<OutputT>> outStream =
- dStream.transform(new Function<JavaRDD<WindowedValue<InputT>>,
- JavaRDD<WindowedValue<OutputT>>>() {
- @Override
- public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd) throws
- Exception {
- final JavaSparkContext jsc = new JavaSparkContext(rdd.context());
- final Accumulator<NamedAggregators> aggAccum = AggregatorsAccumulator.getInstance();
- final Accumulator<SparkMetricsContainer> metricsAccum =
- MetricsAccumulator.getInstance();
- final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
- TranslationUtils.getSideInputs(transform.getSideInputs(),
- jsc, pviews);
- return rdd.mapPartitions(
- new DoFnFunction<>(aggAccum, metricsAccum, stepName, doFn, runtimeContext,
- sideInputs, windowingStrategy));
- }
- });
-
- context.putDataset(transform,
- new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources()));
- }
-
- @Override
- public String toNativeString() {
- return "mapPartitions(new <fn>())";
- }
- };
- }
-
private static <InputT, OutputT> TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>>
multiDo() {
return new TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>>() {
- @Override
- public void evaluate(final ParDo.BoundMulti<InputT, OutputT> transform,
- final EvaluationContext context) {
+ public void evaluate(
+ final ParDo.BoundMulti<InputT, OutputT> transform, final EvaluationContext context) {
final DoFn<InputT, OutputT> doFn = transform.getFn();
rejectSplittable(doFn);
rejectStateAndTimers(doFn);
@@ -435,36 +384,90 @@ public final class StreamingTransformTranslator {
((UnboundedDataset<InputT>) context.borrowDataset(transform));
JavaDStream<WindowedValue<InputT>> dStream = unboundedDataset.getDStream();
- JavaPairDStream<TupleTag<?>, WindowedValue<?>> all = dStream.transformToPair(
- new Function<JavaRDD<WindowedValue<InputT>>,
- JavaPairRDD<TupleTag<?>, WindowedValue<?>>>() {
- @Override
- public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
- JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
- String stepName = context.getCurrentTransform().getFullName();
- final Accumulator<NamedAggregators> aggAccum = AggregatorsAccumulator.getInstance();
- final Accumulator<SparkMetricsContainer> metricsAccum =
- MetricsAccumulator.getInstance();
- final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
- TranslationUtils.getSideInputs(transform.getSideInputs(),
- JavaSparkContext.fromSparkContext(rdd.context()), pviews);
- return rdd.mapPartitionsToPair(new MultiDoFnFunction<>(aggAccum, metricsAccum,
- stepName, doFn, runtimeContext, transform.getMainOutputTag(), sideInputs,
- windowingStrategy));
+ final String stepName = context.getCurrentTransform().getFullName();
+ if (transform.getSideOutputTags().size() == 0) {
+ // Don't tag with the output and filter for a single-output ParDo, as that only adds
+ // identity transforms.
+ // Also see BEAM-1737 for failures when the two versions are condensed.
+ JavaDStream<WindowedValue<OutputT>> outStream =
+ dStream.transform(
+ new Function<JavaRDD<WindowedValue<InputT>>, JavaRDD<WindowedValue<OutputT>>>() {
+ @Override
+ public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd)
+ throws Exception {
+ final JavaSparkContext jsc = new JavaSparkContext(rdd.context());
+ final Accumulator<NamedAggregators> aggAccum =
+ AggregatorsAccumulator.getInstance();
+ final Accumulator<SparkMetricsContainer> metricsAccum =
+ MetricsAccumulator.getInstance();
+ final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
+ sideInputs =
+ TranslationUtils.getSideInputs(
+ transform.getSideInputs(), jsc, pviews);
+ return rdd.mapPartitions(
+ new DoFnFunction<>(
+ aggAccum,
+ metricsAccum,
+ stepName,
+ doFn,
+ runtimeContext,
+ sideInputs,
+ windowingStrategy));
+ }
+ });
+
+ PCollection<OutputT> output =
+ (PCollection<OutputT>)
+ Iterables.getOnlyElement(context.getOutputs(transform)).getValue();
+ context.putDataset(
+ output, new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources()));
+ } else {
+ JavaPairDStream<TupleTag<?>, WindowedValue<?>> all =
+ dStream
+ .transformToPair(
+ new Function<
+ JavaRDD<WindowedValue<InputT>>,
+ JavaPairRDD<TupleTag<?>, WindowedValue<?>>>() {
+ @Override
+ public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
+ JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
+ String stepName = context.getCurrentTransform().getFullName();
+ final Accumulator<NamedAggregators> aggAccum =
+ AggregatorsAccumulator.getInstance();
+ final Accumulator<SparkMetricsContainer> metricsAccum =
+ MetricsAccumulator.getInstance();
+ final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
+ sideInputs =
+ TranslationUtils.getSideInputs(
+ transform.getSideInputs(),
+ JavaSparkContext.fromSparkContext(rdd.context()),
+ pviews);
+ return rdd.mapPartitionsToPair(
+ new MultiDoFnFunction<>(
+ aggAccum,
+ metricsAccum,
+ stepName,
+ doFn,
+ runtimeContext,
+ transform.getMainOutputTag(),
+ sideInputs,
+ windowingStrategy));
+ }
+ })
+ .cache();
+ for (TaggedPValue output : context.getOutputs(transform)) {
+ @SuppressWarnings("unchecked")
+ JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
+ all.filter(new TranslationUtils.TupleTagFilter(output.getTag()));
+ @SuppressWarnings("unchecked")
+ // Object is the best we can do since different outputs can have different tags
+ JavaDStream<WindowedValue<Object>> values =
+ (JavaDStream<WindowedValue<Object>>)
+ (JavaDStream<?>) TranslationUtils.dStreamValues(filtered);
+ context.putDataset(
+ output.getValue(),
+ new UnboundedDataset<>(values, unboundedDataset.getStreamSources()));
}
- }).cache();
- List<TaggedPValue> pct = context.getOutputs(transform);
- for (TaggedPValue e : pct) {
- @SuppressWarnings("unchecked")
- JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
- all.filter(new TranslationUtils.TupleTagFilter(e.getTag()));
- @SuppressWarnings("unchecked")
- // Object is the best we can do since different outputs can have different tags
- JavaDStream<WindowedValue<Object>> values =
- (JavaDStream<WindowedValue<Object>>)
- (JavaDStream<?>) TranslationUtils.dStreamValues(filtered);
- context.putDataset(e.getValue(),
- new UnboundedDataset<>(values, unboundedDataset.getStreamSources()));
}
}
@@ -520,7 +523,6 @@ public final class StreamingTransformTranslator {
EVALUATORS.put(Read.Unbounded.class, readUnbounded());
EVALUATORS.put(GroupByKey.class, groupByKey());
EVALUATORS.put(Combine.GroupedValues.class, combineGrouped());
- EVALUATORS.put(ParDo.Bound.class, parDo());
EVALUATORS.put(ParDo.BoundMulti.class, multiDo());
EVALUATORS.put(ConsoleIO.Write.Unbound.class, print());
EVALUATORS.put(CreateStream.class, createFromQueue());
http://git-wip-us.apache.org/repos/asf/beam/blob/c6cad209/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
index b181a04..d66633b 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
@@ -83,7 +83,7 @@ public class TrackStreamingSourcesTest {
p.apply(emptyStream).apply(ParDo.of(new PassthroughFn<>()));
- p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.Bound.class, 0));
+ p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.BoundMulti.class, 0));
assertThat(StreamingSourceTracker.numAssertions, equalTo(1));
}
@@ -111,7 +111,7 @@ public class TrackStreamingSourcesTest {
PCollectionList.of(pcol1).and(pcol2).apply(Flatten.<Integer>pCollections());
flattened.apply(ParDo.of(new PassthroughFn<>()));
- p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.Bound.class, 0, 1));
+ p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.BoundMulti.class, 0, 1));
assertThat(StreamingSourceTracker.numAssertions, equalTo(1));
}
http://git-wip-us.apache.org/repos/asf/beam/blob/c6cad209/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
index 19c5a2d..9225231 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
@@ -738,12 +738,8 @@ public class ParDo {
@Override
public PCollection<OutputT> expand(PCollection<? extends InputT> input) {
- validateWindowType(input, fn);
- return PCollection.<OutputT>createPrimitiveOutputInternal(
- input.getPipeline(),
- input.getWindowingStrategy(),
- input.isBounded())
- .setTypeDescriptor(getFn().getOutputTypeDescriptor());
+ TupleTag<OutputT> mainOutput = new TupleTag<>();
+ return input.apply(withOutputTags(mainOutput, TupleTagList.empty())).get(mainOutput);
}
@Override