You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by tg...@apache.org on 2017/01/24 00:08:41 UTC
[2/3] beam git commit: Always expand in AppliedPTransform
http://git-wip-us.apache.org/repos/asf/beam/blob/7b062d71/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 3e941e4..fa5ae95 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -18,6 +18,7 @@
package org.apache.beam.runners.spark.translation;
+import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputDirectory;
import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputFilePrefix;
@@ -28,6 +29,7 @@ import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectS
import com.google.common.collect.Maps;
import java.io.IOException;
import java.util.Collections;
+import java.util.List;
import java.util.Map;
import org.apache.avro.mapred.AvroKey;
import org.apache.avro.mapreduce.AvroJob;
@@ -63,9 +65,8 @@ import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowingStrategy;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PCollectionList;
-import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TaggedPValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.NullWritable;
@@ -94,14 +95,19 @@ public final class TransformTranslator {
@SuppressWarnings("unchecked")
@Override
public void evaluate(Flatten.FlattenPCollectionList<T> transform, EvaluationContext context) {
- PCollectionList<T> pcs = context.getInput(transform);
+ List<TaggedPValue> pcs = context.getInputs(transform);
JavaRDD<WindowedValue<T>> unionRDD;
if (pcs.size() == 0) {
unionRDD = context.getSparkContext().emptyRDD();
} else {
JavaRDD<WindowedValue<T>>[] rdds = new JavaRDD[pcs.size()];
for (int i = 0; i < rdds.length; i++) {
- rdds[i] = ((BoundedDataset<T>) context.borrowDataset(pcs.get(i))).getRDD();
+ checkArgument(
+ pcs.get(i).getValue() instanceof PCollection,
+ "Flatten had non-PCollection value in input: %s of type %s",
+ pcs.get(i).getValue(),
+ pcs.get(i).getValue().getClass().getSimpleName());
+ rdds[i] = ((BoundedDataset<T>) context.borrowDataset(pcs.get(i).getValue())).getRDD();
}
unionRDD = context.getSparkContext().union(rdds);
}
@@ -124,9 +130,15 @@ public final class TransformTranslator {
final Accumulator<NamedAggregators> accum =
SparkAggregators.getNamedAggregators(context.getSparkContext());
- context.putDataset(transform,
- new BoundedDataset<>(GroupCombineFunctions.groupByKey(inRDD, accum, coder,
- context.getRuntimeContext(), context.getInput(transform).getWindowingStrategy())));
+ context.putDataset(
+ transform,
+ new BoundedDataset<>(
+ GroupCombineFunctions.groupByKey(
+ inRDD,
+ accum,
+ coder,
+ context.getRuntimeContext(),
+ context.getInput(transform).getWindowingStrategy())));
}
};
}
@@ -265,11 +277,11 @@ public final class TransformTranslator {
new MultiDoFnFunction<>(accum, doFn, context.getRuntimeContext(),
transform.getMainOutputTag(), TranslationUtils.getSideInputs(
transform.getSideInputs(), context), windowingStrategy)).cache();
- PCollectionTuple pct = context.getOutput(transform);
- for (Map.Entry<TupleTag<?>, PCollection<?>> e : pct.getAll().entrySet()) {
+ List<TaggedPValue> pct = context.getOutputs(transform);
+ for (TaggedPValue e : pct) {
@SuppressWarnings("unchecked")
JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
- all.filter(new TranslationUtils.TupleTagFilter(e.getKey()));
+ all.filter(new TranslationUtils.TupleTagFilter(e.getTag()));
@SuppressWarnings("unchecked")
// Object is the best we can do since different outputs can have different tags
JavaRDD<WindowedValue<Object>> values =
@@ -529,7 +541,7 @@ public final class TransformTranslator {
@Override
public void evaluate(View.AsSingleton<T> transform, EvaluationContext context) {
Iterable<? extends WindowedValue<?>> iter =
- context.getWindowedValues(context.getInput(transform));
+ context.getWindowedValues(context.getInput(transform));
PCollectionView<T> output = context.getOutput(transform);
Coder<Iterable<WindowedValue<?>>> coderInternal = output.getCoderInternal();
http://git-wip-us.apache.org/repos/asf/beam/blob/7b062d71/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 3c89b99..a2a1d3b 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.runners.spark.translation.streaming;
+import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers;
@@ -64,8 +65,7 @@ import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowingStrategy;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PCollectionList;
-import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.TaggedPValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.spark.Accumulator;
import org.apache.spark.api.java.JavaPairRDD;
@@ -125,14 +125,20 @@ final class StreamingTransformTranslator {
@SuppressWarnings("unchecked")
@Override
public void evaluate(Flatten.FlattenPCollectionList<T> transform, EvaluationContext context) {
- PCollectionList<T> pcs = context.getInput(transform);
+ List<TaggedPValue> pcs = context.getInputs(transform);
// since this is a streaming pipeline, at least one of the PCollections to "flatten" is
// unbounded, meaning it represents a DStream.
// So we could end up with an unbounded unified DStream.
final List<JavaRDD<WindowedValue<T>>> rdds = new ArrayList<>();
final List<JavaDStream<WindowedValue<T>>> dStreams = new ArrayList<>();
- for (PCollection<T> pcol : pcs.getAll()) {
- Dataset dataset = context.borrowDataset(pcol);
+ for (TaggedPValue pv : pcs) {
+ checkArgument(
+ pv.getValue() instanceof PCollection,
+ "Flatten had non-PCollection value in input: %s of type %s",
+ pv.getValue(),
+ pv.getValue().getClass().getSimpleName());
+ PCollection<T> pcol = (PCollection<T>) pv.getValue();
+ Dataset dataset = context.borrowDataset(pcol);
if (dataset instanceof UnboundedDataset) {
dStreams.add(((UnboundedDataset<T>) dataset).getDStream());
} else {
@@ -144,14 +150,15 @@ final class StreamingTransformTranslator {
context.getStreamingContext().union(dStreams.remove(0), dStreams);
// now unify in RDDs.
if (rdds.size() > 0) {
- JavaDStream<WindowedValue<T>> joined = unifiedStreams.transform(
- new Function<JavaRDD<WindowedValue<T>>, JavaRDD<WindowedValue<T>>>() {
- @Override
- public JavaRDD<WindowedValue<T>> call(JavaRDD<WindowedValue<T>> streamRdd)
- throws Exception {
- return new JavaSparkContext(streamRdd.context()).union(streamRdd, rdds);
- }
- });
+ JavaDStream<WindowedValue<T>> joined =
+ unifiedStreams.transform(
+ new Function<JavaRDD<WindowedValue<T>>, JavaRDD<WindowedValue<T>>>() {
+ @Override
+ public JavaRDD<WindowedValue<T>> call(JavaRDD<WindowedValue<T>> streamRdd)
+ throws Exception {
+ return new JavaSparkContext(streamRdd.context()).union(streamRdd, rdds);
+ }
+ });
context.putDataset(transform, new UnboundedDataset<>(joined));
} else {
context.putDataset(transform, new UnboundedDataset<>(unifiedStreams));
@@ -284,8 +291,9 @@ final class StreamingTransformTranslator {
@SuppressWarnings("unchecked")
@Override
- public void evaluate(final Combine.Globally<InputT, OutputT> transform,
- EvaluationContext context) {
+ public void evaluate(
+ final Combine.Globally<InputT, OutputT> transform,
+ EvaluationContext context) {
final PCollection<InputT> input = context.getInput(transform);
// serializable arguments to pass.
final Coder<InputT> iCoder = context.getInput(transform).getCoder();
@@ -372,7 +380,6 @@ final class StreamingTransformTranslator {
final WindowingStrategy<?, ?> windowingStrategy =
context.getInput(transform).getWindowingStrategy();
final SparkPCollectionView pviews = context.getPViews();
-
JavaDStream<WindowedValue<InputT>> dStream =
((UnboundedDataset<InputT>) context.borrowDataset(transform)).getDStream();
@@ -431,11 +438,11 @@ final class StreamingTransformTranslator {
runtimeContext, transform.getMainOutputTag(), sideInputs, windowingStrategy));
}
}).cache();
- PCollectionTuple pct = context.getOutput(transform);
- for (Map.Entry<TupleTag<?>, PCollection<?>> e : pct.getAll().entrySet()) {
+ List<TaggedPValue> pct = context.getOutputs(transform);
+ for (TaggedPValue e : pct) {
@SuppressWarnings("unchecked")
JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
- all.filter(new TranslationUtils.TupleTagFilter(e.getKey()));
+ all.filter(new TranslationUtils.TupleTagFilter(e.getTag()));
@SuppressWarnings("unchecked")
// Object is the best we can do since different outputs can have different tags
JavaDStream<WindowedValue<Object>> values =
http://git-wip-us.apache.org/repos/asf/beam/blob/7b062d71/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
index 77de54a..a6d8859 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
@@ -18,8 +18,11 @@
package org.apache.beam.sdk.transforms;
import com.google.auto.value.AutoValue;
+import java.util.List;
+import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.POutput;
+import org.apache.beam.sdk.values.TaggedPValue;
/**
* Represents the application of a {@link PTransform} to a specific input to produce
@@ -43,14 +46,16 @@ public abstract class AppliedPTransform
AppliedPTransform<InputT, OutputT, TransformT> of(
String fullName, InputT input, OutputT output, TransformT transform) {
return new AutoValue_AppliedPTransform<InputT, OutputT, TransformT>(
- fullName, input, output, transform);
+ fullName, input.expand(), output.expand(), transform, input.getPipeline());
}
public abstract String getFullName();
- public abstract InputT getInput();
+ public abstract List<TaggedPValue> getInputs();
- public abstract OutputT getOutput();
+ public abstract List<TaggedPValue> getOutputs();
public abstract TransformT getTransform();
+
+ public abstract Pipeline getPipeline();
}