Posted to commits@beam.apache.org by tg...@apache.org on 2017/03/17 15:54:30 UTC

[2/4] beam git commit: Implement Single-Output ParDo as a composite
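
Context for this patch: the heart of the change is the new ParDo.Bound#expand
in sdks/java/core (final hunk of this message). A single-output ParDo no longer
creates a primitive output itself; it expands as a composite over the
multi-output ParDo.BoundMulti and extracts the main output, so runners only
ever see the multi-output primitive. A minimal sketch of the new expansion
(simplified from the hunk below; nothing beyond the SDK types named in this
commit is assumed):

    // Inside ParDo.Bound<InputT, OutputT>: delegate to withOutputTags with an
    // empty side-output list, then pull out the main output PCollection.
    @Override
    public PCollection<OutputT> expand(PCollection<? extends InputT> input) {
      TupleTag<OutputT> mainOutput = new TupleTag<>();
      return input.apply(withOutputTags(mainOutput, TupleTagList.empty())).get(mainOutput);
    }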

http://git-wip-us.apache.org/repos/asf/beam/blob/c6cad209/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java
new file mode 100644
index 0000000..cb1e34e
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.dataflow;
+
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.equalTo;
+import static org.junit.Assert.assertThat;
+
+import com.google.common.collect.Iterables;
+import java.io.Serializable;
+import java.util.List;
+import org.apache.beam.runners.dataflow.PrimitiveParDoSingleFactory.ParDoSingle;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.display.DisplayDataEvaluator;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link PrimitiveParDoSingleFactory}.
+ */
+@RunWith(JUnit4.class)
+public class PrimitiveParDoSingleFactoryTest implements Serializable {
+  // Create a pipeline for testing side input propagation. These tests never actually run a
+  // pipeline, so disable abandoned-node enforcement.
+  @Rule
+  public transient TestPipeline pipeline =
+      TestPipeline.create().enableAbandonedNodeEnforcement(false);
+
+  private PrimitiveParDoSingleFactory<Integer, Long> factory = new PrimitiveParDoSingleFactory<>();
+
+  /**
+   * A test that demonstrates that the replacement transform has the Display Data of the
+   * {@link ParDo.Bound} it replaces.
+   */
+  @Test
+  public void getReplacementTransformPopulateDisplayData() {
+    ParDo.Bound<Integer, Long> originalTransform = ParDo.of(new ToLongFn());
+    DisplayData originalDisplayData = DisplayData.from(originalTransform);
+
+    PTransform<PCollection<? extends Integer>, PCollection<Long>> replacement =
+        factory.getReplacementTransform(originalTransform);
+    DisplayData replacementDisplayData = DisplayData.from(replacement);
+
+    assertThat(replacementDisplayData, equalTo(originalDisplayData));
+
+    DisplayData primitiveDisplayData =
+        Iterables.getOnlyElement(
+            DisplayDataEvaluator.create()
+                .displayDataForPrimitiveTransforms(replacement, VarIntCoder.of()));
+    assertThat(primitiveDisplayData, equalTo(replacementDisplayData));
+  }
+
+  @Test
+  public void getReplacementTransformGetSideInputs() {
+    PCollectionView<Long> sideLong =
+        pipeline
+            .apply("LongSideInputVals", Create.of(-1L, -2L, -4L))
+            .apply("SideLongView", Sum.longsGlobally().asSingletonView());
+    PCollectionView<List<String>> sideStrings =
+        pipeline
+            .apply("StringSideInputVals", Create.of("foo", "bar", "baz"))
+            .apply("SideStringsView", View.<String>asList());
+    ParDo.Bound<Integer, Long> originalTransform =
+        ParDo.of(new ToLongFn()).withSideInputs(sideLong, sideStrings);
+
+    PTransform<PCollection<? extends Integer>, PCollection<Long>> replacementTransform =
+        factory.getReplacementTransform(originalTransform);
+    ParDoSingle<Integer, Long> parDoSingle = (ParDoSingle<Integer, Long>) replacementTransform;
+    assertThat(parDoSingle.getSideInputs(), containsInAnyOrder(sideStrings, sideLong));
+  }
+
+  @Test
+  public void getReplacementTransformGetFn() {
+    DoFn<Integer, Long> originalFn = new ToLongFn();
+    ParDo.Bound<Integer, Long> originalTransform = ParDo.of(originalFn);
+    PTransform<PCollection<? extends Integer>, PCollection<Long>> replacementTransform =
+        factory.getReplacementTransform(originalTransform);
+    ParDoSingle<Integer, Long> parDoSingle = (ParDoSingle<Integer, Long>) replacementTransform;
+
+    assertThat(parDoSingle.getFn(), equalTo(originalTransform.getFn()));
+    assertThat(parDoSingle.getFn(), equalTo(originalFn));
+  }
+
+  private static class ToLongFn extends DoFn<Integer, Long> {
+    @ProcessElement
+    public void toLong(ProcessContext ctxt) {
+      ctxt.output(ctxt.element().longValue());
+    }
+
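+    // Equality is defined by class so the tests above can assert, via equalTo,
+    // that the replacement transform carries the same DoFn as the original.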
+    @Override
+    public boolean equals(Object other) {
+      return other != null && other.getClass().equals(getClass());
+    }
+
+    @Override
+    public int hashCode() {
+      return getClass().hashCode();
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/c6cad209/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 8d1b82e..b4362b0 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -27,6 +27,7 @@ import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.replaceSh
 import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectSplittable;
 import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers;
 
+import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import java.io.IOException;
@@ -348,38 +349,8 @@ public final class TransformTranslator {
     };
   }
 
-  private static <InputT, OutputT> TransformEvaluator<ParDo.Bound<InputT, OutputT>> parDo() {
-    return new TransformEvaluator<ParDo.Bound<InputT, OutputT>>() {
-      @Override
-      public void evaluate(ParDo.Bound<InputT, OutputT> transform, EvaluationContext context) {
-        String stepName = context.getCurrentTransform().getFullName();
-        DoFn<InputT, OutputT> doFn = transform.getFn();
-        rejectSplittable(doFn);
-        rejectStateAndTimers(doFn);
-        @SuppressWarnings("unchecked")
-        JavaRDD<WindowedValue<InputT>> inRDD =
-            ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD();
-        WindowingStrategy<?, ?> windowingStrategy =
-            context.getInput(transform).getWindowingStrategy();
-        Accumulator<NamedAggregators> aggAccum = AggregatorsAccumulator.getInstance();
-        Accumulator<SparkMetricsContainer> metricsAccum =
-            MetricsAccumulator.getInstance();
-        Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
-            TranslationUtils.getSideInputs(transform.getSideInputs(), context);
-        context.putDataset(transform,
-            new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(aggAccum, metricsAccum,
-                stepName, doFn, context.getRuntimeContext(), sideInputs, windowingStrategy))));
-      }
-
-      @Override
-      public String toNativeString() {
-        return "mapPartitions(new <fn>())";
-      }
-    };
-  }
-
   private static <InputT, OutputT> TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>>
-  multiDo() {
+  parDo() {
     return new TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>>() {
       @Override
       public void evaluate(ParDo.BoundMulti<InputT, OutputT> transform, EvaluationContext context) {
@@ -393,24 +364,52 @@ public final class TransformTranslator {
         WindowingStrategy<?, ?> windowingStrategy =
             context.getInput(transform).getWindowingStrategy();
         Accumulator<NamedAggregators> aggAccum = AggregatorsAccumulator.getInstance();
-        Accumulator<SparkMetricsContainer> metricsAccum =
-            MetricsAccumulator.getInstance();
-        JavaPairRDD<TupleTag<?>, WindowedValue<?>> all = inRDD
-            .mapPartitionsToPair(
-                new MultiDoFnFunction<>(aggAccum, metricsAccum, stepName, doFn,
-                    context.getRuntimeContext(), transform.getMainOutputTag(),
-                    TranslationUtils.getSideInputs(transform.getSideInputs(), context),
-                    windowingStrategy)).cache();
-        List<TaggedPValue> pct = context.getOutputs(transform);
-        for (TaggedPValue e : pct) {
-          @SuppressWarnings("unchecked")
-          JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
-              all.filter(new TranslationUtils.TupleTagFilter(e.getTag()));
-          @SuppressWarnings("unchecked")
-          // Object is the best we can do since different outputs can have different tags
-          JavaRDD<WindowedValue<Object>> values =
-              (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values();
-          context.putDataset(e.getValue(), new BoundedDataset<>(values));
+        Accumulator<SparkMetricsContainer> metricsAccum = MetricsAccumulator.getInstance();
+        Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
+            TranslationUtils.getSideInputs(transform.getSideInputs(), context);
+        if (transform.getSideOutputTags().size() == 0) {
+          // For a single-output ParDo, don't tag elements with the output and then filter:
+          // that amounts to a pair of identity transforms.
+          // See also BEAM-1737 for failures when the two code paths are condensed into one.
+          PCollection<OutputT> output =
+              (PCollection<OutputT>)
+                  Iterables.getOnlyElement(context.getOutputs(transform)).getValue();
+          context.putDataset(
+              output,
+              new BoundedDataset<>(
+                  inRDD.mapPartitions(
+                      new DoFnFunction<>(
+                          aggAccum,
+                          metricsAccum,
+                          stepName,
+                          doFn,
+                          context.getRuntimeContext(),
+                          sideInputs,
+                          windowingStrategy))));
+        } else {
+          JavaPairRDD<TupleTag<?>, WindowedValue<?>> all =
+              inRDD
+                  .mapPartitionsToPair(
+                      new MultiDoFnFunction<>(
+                          aggAccum,
+                          metricsAccum,
+                          stepName,
+                          doFn,
+                          context.getRuntimeContext(),
+                          transform.getMainOutputTag(),
+                          TranslationUtils.getSideInputs(transform.getSideInputs(), context),
+                          windowingStrategy))
+                  .cache();
+          for (TaggedPValue output : context.getOutputs(transform)) {
+            @SuppressWarnings("unchecked")
+            JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
+                all.filter(new TranslationUtils.TupleTagFilter(output.getTag()));
+            @SuppressWarnings("unchecked")
+            // Object is the best we can do since different outputs can have different tags
+            JavaRDD<WindowedValue<Object>> values =
+                (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values();
+            context.putDataset(output.getValue(), new BoundedDataset<>(values));
+          }
         }
       }
 
@@ -842,8 +841,7 @@ public final class TransformTranslator {
     EVALUATORS.put(Read.Bounded.class, readBounded());
     EVALUATORS.put(HadoopIO.Read.Bound.class, readHadoop());
     EVALUATORS.put(HadoopIO.Write.Bound.class, writeHadoop());
-    EVALUATORS.put(ParDo.Bound.class, parDo());
-    EVALUATORS.put(ParDo.BoundMulti.class, multiDo());
+    EVALUATORS.put(ParDo.BoundMulti.class, parDo());
     EVALUATORS.put(GroupByKey.class, groupByKey());
     EVALUATORS.put(Combine.GroupedValues.class, combineGrouped());
     EVALUATORS.put(Combine.Globally.class, combineGlobally());
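
For readers skimming the change above: the multi-output path pairs every element
with its TupleTag and then filters the cached pair-RDD once per output, while
the new single-output branch maps partitions directly. A self-contained
illustration of why the tag-and-filter round trip is pure overhead when there
is exactly one output (plain Java collections stand in for RDDs; this is an
editorial sketch, not runner code):

    import java.util.AbstractMap.SimpleEntry;
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;
    import java.util.Map;

    public class TagAndFilterSketch {
      public static void main(String[] args) {
        List<Integer> input = Arrays.asList(1, 2, 3);
        String mainTag = "main"; // stands in for the main-output TupleTag

        // Multi-output style: tag every element, then filter per output tag.
        List<Map.Entry<String, Integer>> tagged = new ArrayList<>();
        for (Integer x : input) {
          tagged.add(new SimpleEntry<>(mainTag, x));
        }
        List<Integer> viaTagging = new ArrayList<>();
        for (Map.Entry<String, Integer> e : tagged) {
          if (e.getKey().equals(mainTag)) {
            viaTagging.add(e.getValue());
          }
        }

        // Single-output style: use the elements directly.
        List<Integer> direct = new ArrayList<>(input);

        // Same result; the tag/filter pass was a pair of identity transforms.
        System.out.println(viaTagging.equals(direct)); // true
      }
    }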

http://git-wip-us.apache.org/repos/asf/beam/blob/c6cad209/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 2744169..25fecf6 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -366,62 +366,11 @@ public final class StreamingTransformTranslator {
     };
   }
 
-  private static <InputT, OutputT> TransformEvaluator<ParDo.Bound<InputT, OutputT>> parDo() {
-    return new TransformEvaluator<ParDo.Bound<InputT, OutputT>>() {
-      @Override
-      public void evaluate(final ParDo.Bound<InputT, OutputT> transform,
-                           final EvaluationContext context) {
-        final DoFn<InputT, OutputT> doFn = transform.getFn();
-        rejectSplittable(doFn);
-        rejectStateAndTimers(doFn);
-        final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
-        final WindowingStrategy<?, ?> windowingStrategy =
-            context.getInput(transform).getWindowingStrategy();
-        final SparkPCollectionView pviews = context.getPViews();
-
-        @SuppressWarnings("unchecked")
-        UnboundedDataset<InputT> unboundedDataset =
-            ((UnboundedDataset<InputT>) context.borrowDataset(transform));
-        JavaDStream<WindowedValue<InputT>> dStream = unboundedDataset.getDStream();
-
-        final String stepName = context.getCurrentTransform().getFullName();
-
-        JavaDStream<WindowedValue<OutputT>> outStream =
-            dStream.transform(new Function<JavaRDD<WindowedValue<InputT>>,
-                JavaRDD<WindowedValue<OutputT>>>() {
-          @Override
-          public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd) throws
-              Exception {
-            final JavaSparkContext jsc = new JavaSparkContext(rdd.context());
-            final Accumulator<NamedAggregators> aggAccum = AggregatorsAccumulator.getInstance();
-            final Accumulator<SparkMetricsContainer> metricsAccum =
-                MetricsAccumulator.getInstance();
-            final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
-                TranslationUtils.getSideInputs(transform.getSideInputs(),
-                    jsc, pviews);
-            return rdd.mapPartitions(
-                new DoFnFunction<>(aggAccum, metricsAccum, stepName, doFn, runtimeContext,
-                    sideInputs, windowingStrategy));
-          }
-        });
-
-        context.putDataset(transform,
-            new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources()));
-      }
-
-      @Override
-      public String toNativeString() {
-        return "mapPartitions(new <fn>())";
-      }
-    };
-  }
-
   private static <InputT, OutputT> TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>>
   multiDo() {
     return new TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>>() {
       @Override
-      public void evaluate(final ParDo.BoundMulti<InputT, OutputT> transform,
-                           final EvaluationContext context) {
+      public void evaluate(
+          final ParDo.BoundMulti<InputT, OutputT> transform, final EvaluationContext context) {
         final DoFn<InputT, OutputT> doFn = transform.getFn();
         rejectSplittable(doFn);
         rejectStateAndTimers(doFn);
@@ -435,36 +384,90 @@ public final class StreamingTransformTranslator {
             ((UnboundedDataset<InputT>) context.borrowDataset(transform));
         JavaDStream<WindowedValue<InputT>> dStream = unboundedDataset.getDStream();
 
-        JavaPairDStream<TupleTag<?>, WindowedValue<?>> all = dStream.transformToPair(
-            new Function<JavaRDD<WindowedValue<InputT>>,
-                JavaPairRDD<TupleTag<?>, WindowedValue<?>>>() {
-          @Override
-          public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
-              JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
-            String stepName = context.getCurrentTransform().getFullName();
-            final Accumulator<NamedAggregators> aggAccum = AggregatorsAccumulator.getInstance();
-            final Accumulator<SparkMetricsContainer> metricsAccum =
-                MetricsAccumulator.getInstance();
-            final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
-                TranslationUtils.getSideInputs(transform.getSideInputs(),
-                    JavaSparkContext.fromSparkContext(rdd.context()), pviews);
-              return rdd.mapPartitionsToPair(new MultiDoFnFunction<>(aggAccum, metricsAccum,
-                  stepName, doFn, runtimeContext, transform.getMainOutputTag(), sideInputs,
-                  windowingStrategy));
+        final String stepName = context.getCurrentTransform().getFullName();
+        if (transform.getSideOutputTags().size() == 0) {
+          // For a single-output ParDo, don't tag elements with the output and then filter:
+          // that amounts to a pair of identity transforms.
+          // See also BEAM-1737 for failures when the two code paths are condensed into one.
+          JavaDStream<WindowedValue<OutputT>> outStream =
+              dStream.transform(
+                  new Function<JavaRDD<WindowedValue<InputT>>, JavaRDD<WindowedValue<OutputT>>>() {
+                    @Override
+                    public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd)
+                        throws Exception {
+                      final JavaSparkContext jsc = new JavaSparkContext(rdd.context());
+                      final Accumulator<NamedAggregators> aggAccum =
+                          AggregatorsAccumulator.getInstance();
+                      final Accumulator<SparkMetricsContainer> metricsAccum =
+                          MetricsAccumulator.getInstance();
+                      final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
+                          sideInputs =
+                              TranslationUtils.getSideInputs(
+                                  transform.getSideInputs(), jsc, pviews);
+                      return rdd.mapPartitions(
+                          new DoFnFunction<>(
+                              aggAccum,
+                              metricsAccum,
+                              stepName,
+                              doFn,
+                              runtimeContext,
+                              sideInputs,
+                              windowingStrategy));
+                    }
+                  });
+
+          PCollection<OutputT> output =
+              (PCollection<OutputT>)
+                  Iterables.getOnlyElement(context.getOutputs(transform)).getValue();
+          context.putDataset(
+              output, new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources()));
+        } else {
+          JavaPairDStream<TupleTag<?>, WindowedValue<?>> all =
+              dStream
+                  .transformToPair(
+                      new Function<
+                          JavaRDD<WindowedValue<InputT>>,
+                          JavaPairRDD<TupleTag<?>, WindowedValue<?>>>() {
+                        @Override
+                        public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
+                            JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
+                          String stepName = context.getCurrentTransform().getFullName();
+                          final Accumulator<NamedAggregators> aggAccum =
+                              AggregatorsAccumulator.getInstance();
+                          final Accumulator<SparkMetricsContainer> metricsAccum =
+                              MetricsAccumulator.getInstance();
+                          final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
+                              sideInputs =
+                                  TranslationUtils.getSideInputs(
+                                      transform.getSideInputs(),
+                                      JavaSparkContext.fromSparkContext(rdd.context()),
+                                      pviews);
+                          return rdd.mapPartitionsToPair(
+                              new MultiDoFnFunction<>(
+                                  aggAccum,
+                                  metricsAccum,
+                                  stepName,
+                                  doFn,
+                                  runtimeContext,
+                                  transform.getMainOutputTag(),
+                                  sideInputs,
+                                  windowingStrategy));
+                        }
+                      })
+                  .cache();
+          for (TaggedPValue output : context.getOutputs(transform)) {
+            @SuppressWarnings("unchecked")
+            JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
+                all.filter(new TranslationUtils.TupleTagFilter(output.getTag()));
+            @SuppressWarnings("unchecked")
+            // Object is the best we can do since different outputs can have different tags
+            JavaDStream<WindowedValue<Object>> values =
+                (JavaDStream<WindowedValue<Object>>)
+                    (JavaDStream<?>) TranslationUtils.dStreamValues(filtered);
+            context.putDataset(
+                output.getValue(),
+                new UnboundedDataset<>(values, unboundedDataset.getStreamSources()));
           }
-        }).cache();
-        List<TaggedPValue> pct = context.getOutputs(transform);
-        for (TaggedPValue e : pct) {
-          @SuppressWarnings("unchecked")
-          JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
-              all.filter(new TranslationUtils.TupleTagFilter(e.getTag()));
-          @SuppressWarnings("unchecked")
-          // Object is the best we can do since different outputs can have different tags
-          JavaDStream<WindowedValue<Object>> values =
-              (JavaDStream<WindowedValue<Object>>)
-                  (JavaDStream<?>) TranslationUtils.dStreamValues(filtered);
-          context.putDataset(e.getValue(),
-              new UnboundedDataset<>(values, unboundedDataset.getStreamSources()));
         }
       }
 
@@ -520,7 +523,6 @@ public final class StreamingTransformTranslator {
     EVALUATORS.put(Read.Unbounded.class, readUnbounded());
     EVALUATORS.put(GroupByKey.class, groupByKey());
     EVALUATORS.put(Combine.GroupedValues.class, combineGrouped());
-    EVALUATORS.put(ParDo.Bound.class, parDo());
     EVALUATORS.put(ParDo.BoundMulti.class, multiDo());
     EVALUATORS.put(ConsoleIO.Write.Unbound.class, print());
     EVALUATORS.put(CreateStream.class, createFromQueue());
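
The streaming translator reuses the batch logic by lifting it through
DStream.transform, which applies an RDD-to-RDD function to each micro-batch. A
minimal sketch of that pattern, assuming the Spark 1.x Java API this runner
targets (FlatMapFunction returning an Iterable) and a hypothetical lengths()
helper standing in for DoFnFunction:

    import java.util.ArrayList;
    import java.util.Iterator;
    import java.util.List;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.function.FlatMapFunction;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.streaming.api.java.JavaDStream;

    class DStreamTransformSketch {
      // Per micro-batch, hand the RDD to the same partition-level function the
      // batch path applies via mapPartitions (DoFnFunction in the translator).
      static JavaDStream<Integer> lengths(JavaDStream<String> in) {
        return in.transform(
            new Function<JavaRDD<String>, JavaRDD<Integer>>() {
              @Override
              public JavaRDD<Integer> call(JavaRDD<String> rdd) {
                return rdd.mapPartitions(
                    new FlatMapFunction<Iterator<String>, Integer>() {
                      @Override
                      public Iterable<Integer> call(Iterator<String> iter) {
                        List<Integer> out = new ArrayList<>();
                        while (iter.hasNext()) {
                          out.add(iter.next().length());
                        }
                        return out;
                      }
                    });
              }
            });
      }
    }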

http://git-wip-us.apache.org/repos/asf/beam/blob/c6cad209/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
index b181a04..d66633b 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
@@ -83,7 +83,7 @@ public class TrackStreamingSourcesTest {
 
     p.apply(emptyStream).apply(ParDo.of(new PassthroughFn<>()));
 
-    p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.Bound.class,  0));
+    p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.BoundMulti.class, 0));
     assertThat(StreamingSourceTracker.numAssertions, equalTo(1));
   }
 
@@ -111,7 +111,7 @@ public class TrackStreamingSourcesTest {
         PCollectionList.of(pcol1).and(pcol2).apply(Flatten.<Integer>pCollections());
     flattened.apply(ParDo.of(new PassthroughFn<>()));
 
-    p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.Bound.class, 0, 1));
+    p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.BoundMulti.class, 0, 1));
     assertThat(StreamingSourceTracker.numAssertions, equalTo(1));
   }
 

http://git-wip-us.apache.org/repos/asf/beam/blob/c6cad209/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
index 19c5a2d..9225231 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
@@ -738,12 +738,8 @@ public class ParDo {
 
     @Override
     public PCollection<OutputT> expand(PCollection<? extends InputT> input) {
-      validateWindowType(input, fn);
-      return PCollection.<OutputT>createPrimitiveOutputInternal(
-              input.getPipeline(),
-              input.getWindowingStrategy(),
-              input.isBounded())
-          .setTypeDescriptor(getFn().getOutputTypeDescriptor());
+      TupleTag<OutputT> mainOutput = new TupleTag<>();
+      return input.apply(withOutputTags(mainOutput, TupleTagList.empty())).get(mainOutput);
     }
 
     @Override