Posted to commits@beam.apache.org by tg...@apache.org on 2017/01/24 00:08:41 UTC

[2/3] beam git commit: Always expand in AppliedPTransform
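
The change replaces AppliedPTransform's typed getInput()/getOutput() accessors with pre-expanded List<TaggedPValue> views of both ends of the transform (see the AppliedPTransform.java hunk at the end of this message). A minimal sketch of the reworked surface; `applied` stands in for any transform application a runner holds:

    // Sketch only: `applied` is a placeholder for an AppliedPTransform.
    List<TaggedPValue> inputs = applied.getInputs();    // was: InputT getInput()
    List<TaggedPValue> outputs = applied.getOutputs();  // was: OutputT getOutput()
    Pipeline pipeline = applied.getPipeline();          // new accessor; the expanded
                                                        // lists no longer carry it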

http://git-wip-us.apache.org/repos/asf/beam/blob/7b062d71/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 3e941e4..fa5ae95 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -18,6 +18,7 @@
 
 package org.apache.beam.runners.spark.translation;
 
+import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkState;
 import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputDirectory;
 import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputFilePrefix;
@@ -28,6 +29,7 @@ import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectS
 import com.google.common.collect.Maps;
 import java.io.IOException;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import org.apache.avro.mapred.AvroKey;
 import org.apache.avro.mapreduce.AvroJob;
@@ -63,9 +65,8 @@ import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PCollectionList;
-import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TaggedPValue;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.NullWritable;
@@ -94,14 +95,19 @@ public final class TransformTranslator {
       @SuppressWarnings("unchecked")
       @Override
       public void evaluate(Flatten.FlattenPCollectionList<T> transform, EvaluationContext context) {
-        PCollectionList<T> pcs = context.getInput(transform);
+        List<TaggedPValue> pcs = context.getInputs(transform);
         JavaRDD<WindowedValue<T>> unionRDD;
         if (pcs.size() == 0) {
           unionRDD = context.getSparkContext().emptyRDD();
         } else {
           JavaRDD<WindowedValue<T>>[] rdds = new JavaRDD[pcs.size()];
           for (int i = 0; i < rdds.length; i++) {
-            rdds[i] = ((BoundedDataset<T>) context.borrowDataset(pcs.get(i))).getRDD();
+            checkArgument(
+                pcs.get(i).getValue() instanceof PCollection,
+                "Flatten had non-PCollection value in input: %s of type %s",
+                pcs.get(i).getValue(),
+                pcs.get(i).getValue().getClass().getSimpleName());
+            rdds[i] = ((BoundedDataset<T>) context.borrowDataset(pcs.get(i).getValue())).getRDD();
           }
           unionRDD = context.getSparkContext().union(rdds);
         }
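
With inputs pre-expanded, the Flatten translation above receives TaggedPValue entries rather than a typed PCollectionList, so each value must be checked before casting back to PCollection. A sketch of that unwrap-and-check pattern factored into a helper; the helper name is hypothetical and assumes the checkArgument, List, and ArrayList imports this file already has:

    // Hypothetical helper; mirrors the guard used in the hunk above.
    @SuppressWarnings("unchecked")
    private static <T> List<PCollection<T>> toPCollections(List<TaggedPValue> pcs) {
      List<PCollection<T>> result = new ArrayList<>(pcs.size());
      for (TaggedPValue tpv : pcs) {
        checkArgument(
            tpv.getValue() instanceof PCollection,
            "Flatten had non-PCollection value in input: %s of type %s",
            tpv.getValue(),
            tpv.getValue().getClass().getSimpleName());
        result.add((PCollection<T>) tpv.getValue());
      }
      return result;
    }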
@@ -124,9 +130,15 @@ public final class TransformTranslator {
         final Accumulator<NamedAggregators> accum =
             SparkAggregators.getNamedAggregators(context.getSparkContext());
 
-        context.putDataset(transform,
-            new BoundedDataset<>(GroupCombineFunctions.groupByKey(inRDD, accum, coder,
-                context.getRuntimeContext(), context.getInput(transform).getWindowingStrategy())));
+        context.putDataset(
+            transform,
+            new BoundedDataset<>(
+                GroupCombineFunctions.groupByKey(
+                    inRDD,
+                    accum,
+                    coder,
+                    context.getRuntimeContext(),
+                    context.getInput(transform).getWindowingStrategy())));
       }
     };
   }
@@ -265,11 +277,11 @@ public final class TransformTranslator {
                 new MultiDoFnFunction<>(accum, doFn, context.getRuntimeContext(),
                 transform.getMainOutputTag(), TranslationUtils.getSideInputs(
                     transform.getSideInputs(), context), windowingStrategy)).cache();
-        PCollectionTuple pct = context.getOutput(transform);
-        for (Map.Entry<TupleTag<?>, PCollection<?>> e : pct.getAll().entrySet()) {
+        List<TaggedPValue> pct = context.getOutputs(transform);
+        for (TaggedPValue e : pct) {
           @SuppressWarnings("unchecked")
           JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
-              all.filter(new TranslationUtils.TupleTagFilter(e.getKey()));
+              all.filter(new TranslationUtils.TupleTagFilter(e.getTag()));
           @SuppressWarnings("unchecked")
           // Object is the best we can do since different outputs can have different tags
           JavaRDD<WindowedValue<Object>> values =
@@ -529,7 +541,7 @@ public final class TransformTranslator {
       @Override
       public void evaluate(View.AsSingleton<T> transform, EvaluationContext context) {
         Iterable<? extends WindowedValue<?>> iter =
-            context.getWindowedValues(context.getInput(transform));
+        context.getWindowedValues(context.getInput(transform));
         PCollectionView<T> output = context.getOutput(transform);
         Coder<Iterable<WindowedValue<?>>> coderInternal = output.getCoderInternal();
 

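In the multi-output ParDo hunk above, the outputs loop switches from PCollectionTuple's Map.Entry view to iterating TaggedPValue directly. The shape of that change, sketched with `pct` as in the hunk:

    // Sketch: what each loop element now provides.
    for (TaggedPValue e : pct) {
      TupleTag<?> tag = e.getTag();   // was: e.getKey() on Map.Entry
      PValue value = e.getValue();    // was: e.getValue(), typed as PCollection<?>
      // the tag feeds TranslationUtils.TupleTagFilter, as shown above
    }
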
http://git-wip-us.apache.org/repos/asf/beam/blob/7b062d71/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 3c89b99..a2a1d3b 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.runners.spark.translation.streaming;
 
+import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkState;
 import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers;
 
@@ -64,8 +65,7 @@ import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PCollectionList;
-import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.TaggedPValue;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.spark.Accumulator;
 import org.apache.spark.api.java.JavaPairRDD;
@@ -125,14 +125,20 @@ final class StreamingTransformTranslator {
       @SuppressWarnings("unchecked")
       @Override
       public void evaluate(Flatten.FlattenPCollectionList<T> transform, EvaluationContext context) {
-        PCollectionList<T> pcs = context.getInput(transform);
+        List<TaggedPValue> pcs = context.getInputs(transform);
         // since this is a streaming pipeline, at least one of the PCollections to "flatten" is
         // unbounded, meaning it represents a DStream.
         // So we could end up with an unbounded unified DStream.
         final List<JavaRDD<WindowedValue<T>>> rdds = new ArrayList<>();
         final List<JavaDStream<WindowedValue<T>>> dStreams = new ArrayList<>();
-        for (PCollection<T> pcol : pcs.getAll()) {
-         Dataset dataset = context.borrowDataset(pcol);
+        for (TaggedPValue pv : pcs) {
+          checkArgument(
+              pv.getValue() instanceof PCollection,
+              "Flatten had non-PCollection value in input: %s of type %s",
+              pv.getValue(),
+              pv.getValue().getClass().getSimpleName());
+          PCollection<T> pcol = (PCollection<T>) pv.getValue();
+          Dataset dataset = context.borrowDataset(pcol);
           if (dataset instanceof UnboundedDataset) {
             dStreams.add(((UnboundedDataset<T>) dataset).getDStream());
           } else {
@@ -144,14 +150,15 @@ final class StreamingTransformTranslator {
             context.getStreamingContext().union(dStreams.remove(0), dStreams);
         // now unify in RDDs.
         if (rdds.size() > 0) {
-          JavaDStream<WindowedValue<T>> joined = unifiedStreams.transform(
-              new Function<JavaRDD<WindowedValue<T>>, JavaRDD<WindowedValue<T>>>() {
-            @Override
-            public JavaRDD<WindowedValue<T>> call(JavaRDD<WindowedValue<T>> streamRdd)
-                throws Exception {
-              return new JavaSparkContext(streamRdd.context()).union(streamRdd, rdds);
-            }
-          });
+          JavaDStream<WindowedValue<T>> joined =
+              unifiedStreams.transform(
+                  new Function<JavaRDD<WindowedValue<T>>, JavaRDD<WindowedValue<T>>>() {
+                    @Override
+                    public JavaRDD<WindowedValue<T>> call(JavaRDD<WindowedValue<T>> streamRdd)
+                        throws Exception {
+                      return new JavaSparkContext(streamRdd.context()).union(streamRdd, rdds);
+                    }
+                  });
           context.putDataset(transform, new UnboundedDataset<>(joined));
         } else {
           context.putDataset(transform, new UnboundedDataset<>(unifiedStreams));
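
The streaming Flatten keeps the same check-then-cast guard as the batch path, then splits inputs by boundedness. Restating the merge strategy from the hunk above compactly (names as in the hunk):

    // 1) union the DStreams; a streaming pipeline guarantees at least one.
    JavaDStream<WindowedValue<T>> unified =
        context.getStreamingContext().union(dStreams.remove(0), dStreams);
    // 2) fold any bounded inputs (collected as RDDs) into every micro-batch.
    if (!rdds.isEmpty()) {
      unified = unified.transform(
          new Function<JavaRDD<WindowedValue<T>>, JavaRDD<WindowedValue<T>>>() {
            @Override
            public JavaRDD<WindowedValue<T>> call(JavaRDD<WindowedValue<T>> batch) {
              return new JavaSparkContext(batch.context()).union(batch, rdds);
            }
          });
    }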
@@ -284,8 +291,9 @@ final class StreamingTransformTranslator {
 
       @SuppressWarnings("unchecked")
       @Override
-      public void evaluate(final Combine.Globally<InputT, OutputT> transform,
-                           EvaluationContext context) {
+      public void evaluate(
+          final Combine.Globally<InputT, OutputT> transform,
+          EvaluationContext context) {
         final PCollection<InputT> input = context.getInput(transform);
         // serializable arguments to pass.
         final Coder<InputT> iCoder = context.getInput(transform).getCoder();
@@ -372,7 +380,6 @@ final class StreamingTransformTranslator {
         final WindowingStrategy<?, ?> windowingStrategy =
             context.getInput(transform).getWindowingStrategy();
         final SparkPCollectionView pviews = context.getPViews();
-
         JavaDStream<WindowedValue<InputT>> dStream =
             ((UnboundedDataset<InputT>) context.borrowDataset(transform)).getDStream();
 
@@ -431,11 +438,11 @@ final class StreamingTransformTranslator {
                   runtimeContext, transform.getMainOutputTag(), sideInputs, windowingStrategy));
           }
         }).cache();
-        PCollectionTuple pct = context.getOutput(transform);
-        for (Map.Entry<TupleTag<?>, PCollection<?>> e : pct.getAll().entrySet()) {
+        List<TaggedPValue> pct = context.getOutputs(transform);
+        for (TaggedPValue e : pct) {
           @SuppressWarnings("unchecked")
           JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
-              all.filter(new TranslationUtils.TupleTagFilter(e.getKey()));
+              all.filter(new TranslationUtils.TupleTagFilter(e.getTag()));
           @SuppressWarnings("unchecked")
           // Object is the best we can do since different outputs can have different tags
           JavaDStream<WindowedValue<Object>> values =

http://git-wip-us.apache.org/repos/asf/beam/blob/7b062d71/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
index 77de54a..a6d8859 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
@@ -18,8 +18,11 @@
 package org.apache.beam.sdk.transforms;
 
 import com.google.auto.value.AutoValue;
+import java.util.List;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
+import org.apache.beam.sdk.values.TaggedPValue;
 
 /**
  * Represents the application of a {@link PTransform} to a specific input to produce
@@ -43,14 +46,16 @@ public abstract class AppliedPTransform
       AppliedPTransform<InputT, OutputT, TransformT> of(
           String fullName, InputT input, OutputT output, TransformT transform) {
     return new AutoValue_AppliedPTransform<InputT, OutputT, TransformT>(
-        fullName, input, output, transform);
+        fullName, input.expand(), output.expand(), transform, input.getPipeline());
   }
 
   public abstract String getFullName();
 
-  public abstract InputT getInput();
+  public abstract List<TaggedPValue> getInputs();
 
-  public abstract OutputT getOutput();
+  public abstract List<TaggedPValue> getOutputs();
 
   public abstract TransformT getTransform();
+
+  public abstract Pipeline getPipeline();
 }
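
Since of() now expands eagerly, runner code that previously used the typed accessors has to recover values from the expanded lists. A hypothetical single-input adaptation (illustrative only; `applied` as before):

    // Hypothetical: recover the sole input of a single-input transform.
    List<TaggedPValue> inputs = applied.getInputs();
    checkArgument(inputs.size() == 1, "Expected a single input, got %s", inputs.size());
    PValue onlyInput = inputs.get(0).getValue();
    // Pipeline access moves to the dedicated accessor:
    Pipeline pipeline = applied.getPipeline();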