You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nemo.apache.org by jo...@apache.org on 2018/08/17 05:00:05 UTC
[incubator-nemo] branch master updated: [NEMO-183] DAG-centric
translation from Beam pipeline to IR DAG (#104)
This is an automated email from the ASF dual-hosted git repository.
johnyangk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-nemo.git
The following commit(s) were added to refs/heads/master by this push:
new 4e09f9a [NEMO-183] DAG-centric translation from Beam pipeline to IR DAG (#104)
4e09f9a is described below
commit 4e09f9a7be603e5e57cf3f10f3526ba9e0d1662e
Author: Jangho Seo <ja...@jangho.io>
AuthorDate: Fri Aug 17 13:59:58 2018 +0900
[NEMO-183] DAG-centric translation from Beam pipeline to IR DAG (#104)
JIRA: [NEMO-183: DAG-centric translation from Beam pipeline to IR DAG](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-183)
**Major changes:**
- Implemented DAG-centric translation from Beam pipeline to Nemo IR, which consists of two phases:
- **PipelineVisitor** traverses through the given Beam pipeline to construct DAG of Beam transforms, while preserving the hierarchy of CompositeTransforms.
- **PipelineTranslator** defines not only mappings between PrimitiveTransform and IRVertex, but also correspondences between CompositeTransform and TranslationContext, based on which PipelineTranslator can tune translation behavior.
**Minor changes to note:**
- N/A
**Tests for the changes:**
- Modified existing tests to match the topological ordering produced by the new visitor.
**Other comments:**
- N/A
resolves [NEMO-183](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-183)
---
.../compiler/frontend/beam/NemoPipelineRunner.java | 12 +-
.../frontend/beam/NemoPipelineVisitor.java | 304 ------------
.../compiler/frontend/beam/PipelineTranslator.java | 544 +++++++++++++++++++++
.../compiler/frontend/beam/PipelineVisitor.java | 296 +++++++++++
.../frontend/beam/BeamFrontendALSTest.java | 26 +-
.../frontend/beam/BeamFrontendMLRTest.java | 26 +-
.../TransientResourceCompositePassTest.java | 30 +-
.../reshaping/LoopExtractionPassTest.java | 2 +-
.../LoopInvariantCodeMotionALSInefficientTest.java | 2 +-
.../reshaping/LoopInvariantCodeMotionPassTest.java | 26 +-
.../java/edu/snu/nemo/examples/beam/WordCount.java | 3 +-
11 files changed, 904 insertions(+), 367 deletions(-)
diff --git a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineRunner.java b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineRunner.java
index 3ced565..2342dc2 100644
--- a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineRunner.java
+++ b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineRunner.java
@@ -17,7 +17,8 @@ package edu.snu.nemo.compiler.frontend.beam;
import edu.snu.nemo.client.JobLauncher;
import edu.snu.nemo.common.dag.DAG;
-import edu.snu.nemo.common.dag.DAGBuilder;
+import edu.snu.nemo.common.ir.edge.IREdge;
+import edu.snu.nemo.common.ir.vertex.IRVertex;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineRunner;
import org.apache.beam.sdk.options.PipelineOptions;
@@ -53,10 +54,11 @@ public final class NemoPipelineRunner extends PipelineRunner<NemoPipelineResult>
* @return The result of the pipeline.
*/
public NemoPipelineResult run(final Pipeline pipeline) {
- final DAGBuilder builder = new DAGBuilder<>();
- final NemoPipelineVisitor nemoPipelineVisitor = new NemoPipelineVisitor(builder, nemoPipelineOptions);
- pipeline.traverseTopologically(nemoPipelineVisitor);
- final DAG dag = builder.build();
+ final PipelineVisitor pipelineVisitor = new PipelineVisitor();
+ pipeline.traverseTopologically(pipelineVisitor);
+ final DAG<IRVertex, IREdge> dag = PipelineTranslator.translate(pipelineVisitor.getConvertedPipeline(),
+ nemoPipelineOptions);
+
final NemoPipelineResult nemoPipelineResult = new NemoPipelineResult();
JobLauncher.launchDAG(dag);
return nemoPipelineResult;
diff --git a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineVisitor.java b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineVisitor.java
deleted file mode 100644
index 38b07a3..0000000
--- a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineVisitor.java
+++ /dev/null
@@ -1,304 +0,0 @@
-/*
- * Copyright (C) 2018 Seoul National University
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package edu.snu.nemo.compiler.frontend.beam;
-
-import edu.snu.nemo.common.Pair;
-import edu.snu.nemo.common.ir.edge.executionproperty.*;
-import edu.snu.nemo.common.ir.vertex.transform.Transform;
-import edu.snu.nemo.compiler.frontend.beam.coder.BeamDecoderFactory;
-import edu.snu.nemo.compiler.frontend.beam.coder.BeamEncoderFactory;
-import edu.snu.nemo.common.dag.DAGBuilder;
-import edu.snu.nemo.common.ir.edge.IREdge;
-import edu.snu.nemo.compiler.frontend.beam.source.BeamBoundedSourceVertex;
-import edu.snu.nemo.common.ir.vertex.IRVertex;
-import edu.snu.nemo.common.ir.vertex.LoopVertex;
-import edu.snu.nemo.common.ir.vertex.OperatorVertex;
-
-import edu.snu.nemo.compiler.frontend.beam.transform.*;
-
-import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.coders.*;
-import org.apache.beam.sdk.io.Read;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.runners.TransformHierarchy;
-import org.apache.beam.sdk.transforms.*;
-import org.apache.beam.sdk.transforms.windowing.Window;
-import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.beam.sdk.values.PCollectionViews;
-import org.apache.beam.sdk.values.PValue;
-import org.apache.beam.sdk.values.TupleTag;
-
-import java.util.*;
-import java.util.stream.Collectors;
-
-/**
- * Visits every node in the beam dag to translate the BEAM program to the IR.
- */
-public final class NemoPipelineVisitor extends Pipeline.PipelineVisitor.Defaults {
- private final DAGBuilder<IRVertex, IREdge> builder;
- private final Map<PValue, IRVertex> pValueToVertex;
- private final PipelineOptions options;
- // loopVertexStack keeps track of where the beam program is: whether it is inside a composite transform or it is not.
- private final Stack<LoopVertex> loopVertexStack;
- private final Map<PValue, Pair<BeamEncoderFactory, BeamDecoderFactory>> pValueToCoder;
- private final Map<IRVertex, Pair<BeamEncoderFactory, BeamDecoderFactory>> sideInputCoder;
- private final Map<PValue, TupleTag> pValueToTag;
- private final Map<IRVertex, Set<PValue>> additionalInputs;
-
- /**
- * Constructor of the BEAM Visitor.
- *
- * @param builder DAGBuilder to build the DAG with.
- * @param options Pipeline options.
- */
- public NemoPipelineVisitor(final DAGBuilder<IRVertex, IREdge> builder, final PipelineOptions options) {
- this.builder = builder;
- this.pValueToVertex = new HashMap<>();
- this.options = options;
- this.loopVertexStack = new Stack<>();
- this.pValueToCoder = new HashMap<>();
- this.sideInputCoder = new HashMap<>();
- this.pValueToTag = new HashMap<>();
- this.additionalInputs = new HashMap<>();
- }
-
- @Override
- public CompositeBehavior enterCompositeTransform(final TransformHierarchy.Node beamNode) {
- if (beamNode.getTransform() instanceof LoopCompositeTransform) {
- final LoopVertex loopVertex = new LoopVertex(beamNode.getFullName());
- this.builder.addVertex(loopVertex, this.loopVertexStack);
- this.builder.removeVertex(loopVertex);
- this.loopVertexStack.push(loopVertex);
- }
- return CompositeBehavior.ENTER_TRANSFORM;
- }
-
- @Override
- public void leaveCompositeTransform(final TransformHierarchy.Node beamNode) {
- if (beamNode.getTransform() instanceof LoopCompositeTransform) {
- this.loopVertexStack.pop();
- }
- }
-
- @Override
- public void visitPrimitiveTransform(final TransformHierarchy.Node beamNode) {
-// Print if needed for development
-// LOG.info("visitp " + beamNode.getTransform());
- final IRVertex irVertex =
- convertToVertex(beamNode, builder, pValueToVertex, sideInputCoder, pValueToTag, additionalInputs,
- options, loopVertexStack);
- beamNode.getOutputs().values().stream().filter(v -> v instanceof PCollection).map(v -> (PCollection) v)
- .forEach(output -> pValueToCoder.put(output,
- Pair.of(new BeamEncoderFactory(output.getCoder()), new BeamDecoderFactory(output.getCoder()))));
-
- beamNode.getOutputs().values().forEach(output -> pValueToVertex.put(output, irVertex));
- final Set<PValue> additionalInputsForThisVertex = additionalInputs.getOrDefault(irVertex, new HashSet<>());
- beamNode.getInputs().values().stream().filter(pValueToVertex::containsKey)
- .filter(pValue -> !additionalInputsForThisVertex.contains(pValue))
- .forEach(pValue -> {
- final boolean isAdditionalOutput = pValueToTag.containsKey(pValue);
- final IRVertex src = pValueToVertex.get(pValue);
- final IREdge edge = new IREdge(getEdgeCommunicationPattern(src, irVertex), src, irVertex, false);
- final Pair<BeamEncoderFactory, BeamDecoderFactory> coderPair = pValueToCoder.get(pValue);
- edge.setProperty(EncoderProperty.of(coderPair.left()));
- edge.setProperty(DecoderProperty.of(coderPair.right()));
- edge.setProperty(KeyExtractorProperty.of(new BeamKeyExtractor()));
- // Apply AdditionalOutputTatProperty to edges that corresponds to additional outputs.
- if (isAdditionalOutput) {
- edge.setProperty(AdditionalOutputTagProperty.of(pValueToTag.get(pValue).getId()));
- }
- this.builder.connectVertices(edge);
- });
- }
-
- /**
- * Convert Beam node to IR vertex.
- *
- * @param beamNode input beam node.
- * @param builder the DAG builder to add the vertex to.
- * @param pValueToVertex PValue to Vertex map.
- * @param sideInputCoder Side input EncoderFactory and DecoderFactory map.
- * @param pValueToTag PValue to Tag map.
- * @param additionalInputs additional inputs.
- * @param options pipeline options.
- * @param loopVertexStack Stack to get the current loop vertex that the operator vertex will be assigned to.
- * @param <I> input type.
- * @param <O> output type.
- * @return newly created vertex.
- */
- private static <I, O> IRVertex
- convertToVertex(final TransformHierarchy.Node beamNode,
- final DAGBuilder<IRVertex, IREdge> builder,
- final Map<PValue, IRVertex> pValueToVertex,
- final Map<IRVertex, Pair<BeamEncoderFactory, BeamDecoderFactory>> sideInputCoder,
- final Map<PValue, TupleTag> pValueToTag,
- final Map<IRVertex, Set<PValue>> additionalInputs,
- final PipelineOptions options,
- final Stack<LoopVertex> loopVertexStack) {
- final PTransform beamTransform = beamNode.getTransform();
- final IRVertex irVertex;
- if (beamTransform instanceof Read.Bounded) {
- final Read.Bounded<O> read = (Read.Bounded) beamTransform;
- irVertex = new BeamBoundedSourceVertex<>(read.getSource());
- builder.addVertex(irVertex, loopVertexStack);
- } else if (beamTransform instanceof GroupByKey) {
- irVertex = new OperatorVertex(new GroupByKeyTransform());
- builder.addVertex(irVertex, loopVertexStack);
- } else if (beamTransform instanceof View.CreatePCollectionView) {
- final View.CreatePCollectionView view = (View.CreatePCollectionView) beamTransform;
- final CreateViewTransform transform = new CreateViewTransform(view.getView());
- irVertex = new OperatorVertex(transform);
- pValueToVertex.put(view.getView(), irVertex);
- builder.addVertex(irVertex, loopVertexStack);
- // Coders for outgoing edges in CreateViewTransform.
- // Since outgoing PValues for CreateViewTransform is PCollectionView,
- // we cannot use PCollection::getEncoderFactory to obtain coders.
- final Coder beamInputCoder = beamNode.getInputs().values().stream()
- .filter(v -> v instanceof PCollection).map(v -> (PCollection) v).findFirst()
- .orElseThrow(() -> new RuntimeException("No inputs provided to " + beamNode.getFullName())).getCoder();
- beamNode.getOutputs().values().stream()
- .forEach(output ->
- sideInputCoder.put(irVertex, getCoderPairForView(view.getView().getViewFn(), beamInputCoder)));
- } else if (beamTransform instanceof Window) {
- final Window<I> window = (Window<I>) beamTransform;
- final WindowTransform transform = new WindowTransform(window.getWindowFn());
- irVertex = new OperatorVertex(transform);
- builder.addVertex(irVertex, loopVertexStack);
- } else if (beamTransform instanceof Window.Assign) {
- final Window.Assign<I> window = (Window.Assign<I>) beamTransform;
- final WindowTransform transform = new WindowTransform(window.getWindowFn());
- irVertex = new OperatorVertex(transform);
- builder.addVertex(irVertex, loopVertexStack);
- } else if (beamTransform instanceof ParDo.SingleOutput) {
- final ParDo.SingleOutput<I, O> parDo = (ParDo.SingleOutput<I, O>) beamTransform;
- final DoTransform transform = new DoTransform(parDo.getFn(), options);
- irVertex = new OperatorVertex(transform);
- additionalInputs.put(irVertex, parDo.getAdditionalInputs().values().stream().collect(Collectors.toSet()));
- builder.addVertex(irVertex, loopVertexStack);
- connectSideInputs(builder, parDo.getSideInputs(), pValueToVertex, sideInputCoder, irVertex);
- } else if (beamTransform instanceof ParDo.MultiOutput) {
- final ParDo.MultiOutput<I, O> parDo = (ParDo.MultiOutput<I, O>) beamTransform;
- final DoTransform transform = new DoTransform(parDo.getFn(), options);
- irVertex = new OperatorVertex(transform);
- additionalInputs.put(irVertex, parDo.getAdditionalInputs().values().stream().collect(Collectors.toSet()));
- if (parDo.getAdditionalOutputTags().size() > 0) {
- // Store PValue to additional tag id mapping.
- beamNode.getOutputs().entrySet().stream()
- .filter(kv -> !kv.getKey().equals(parDo.getMainOutputTag()))
- .forEach(kv -> pValueToTag.put(kv.getValue(), kv.getKey()));
- }
- builder.addVertex(irVertex, loopVertexStack);
- connectSideInputs(builder, parDo.getSideInputs(), pValueToVertex, sideInputCoder, irVertex);
- } else if (beamTransform instanceof Flatten.PCollections) {
- irVertex = new OperatorVertex(new FlattenTransform());
- builder.addVertex(irVertex, loopVertexStack);
- } else {
- throw new UnsupportedOperationException(beamTransform.toString());
- }
- return irVertex;
- }
-
- /**
- * Connect side inputs to the vertex.
- *
- * @param builder the DAG builder to add the vertex to.
- * @param sideInputs side inputs.
- * @param pValueToVertex PValue to Vertex map.
- * @param coderMap Side input to Encoder/Decoder factory map.
- * @param irVertex wrapper for a user operation in the IR. (Where the side input is headed to)
- */
- private static void connectSideInputs(final DAGBuilder<IRVertex, IREdge> builder,
- final List<PCollectionView<?>> sideInputs,
- final Map<PValue, IRVertex> pValueToVertex,
- final Map<IRVertex, Pair<BeamEncoderFactory, BeamDecoderFactory>> coderMap,
- final IRVertex irVertex) {
- sideInputs.stream().filter(pValueToVertex::containsKey)
- .forEach(pValue -> {
- final IRVertex src = pValueToVertex.get(pValue);
- final IREdge edge = new IREdge(getEdgeCommunicationPattern(src, irVertex),
- src, irVertex, true);
- final Pair<BeamEncoderFactory, BeamDecoderFactory> coder = coderMap.get(src);
- edge.setProperty(EncoderProperty.of(coder.left()));
- edge.setProperty(DecoderProperty.of(coder.right()));
- edge.setProperty(KeyExtractorProperty.of(new BeamKeyExtractor()));
- builder.connectVertices(edge);
- });
- }
-
- /**
- * Get appropriate encoder and decoder pair for {@link PCollectionView}.
- *
- * @param viewFn {@link ViewFn} from the corresponding {@link View.CreatePCollectionView} transform
- * @param beamInputCoder Beam {@link Coder} for input value to {@link View.CreatePCollectionView}
- * @return appropriate pair of {@link BeamEncoderFactory} and {@link BeamDecoderFactory}
- */
- private static Pair<BeamEncoderFactory, BeamDecoderFactory> getCoderPairForView(final ViewFn viewFn,
- final Coder beamInputCoder) {
- final Coder beamOutputCoder;
- if (viewFn instanceof PCollectionViews.IterableViewFn) {
- beamOutputCoder = IterableCoder.of(beamInputCoder);
- } else if (viewFn instanceof PCollectionViews.ListViewFn) {
- beamOutputCoder = ListCoder.of(beamInputCoder);
- } else if (viewFn instanceof PCollectionViews.MapViewFn) {
- final KvCoder inputCoder = (KvCoder) beamInputCoder;
- beamOutputCoder = MapCoder.of(inputCoder.getKeyCoder(), inputCoder.getValueCoder());
- } else if (viewFn instanceof PCollectionViews.MultimapViewFn) {
- final KvCoder inputCoder = (KvCoder) beamInputCoder;
- beamOutputCoder = MapCoder.of(inputCoder.getKeyCoder(), IterableCoder.of(inputCoder.getValueCoder()));
- } else if (viewFn instanceof PCollectionViews.SingletonViewFn) {
- beamOutputCoder = beamInputCoder;
- } else {
- throw new UnsupportedOperationException("Unsupported viewFn: " + viewFn.getClass());
- }
- return Pair.of(new BeamEncoderFactory(beamOutputCoder), new BeamDecoderFactory(beamOutputCoder));
- }
-
- /**
- * Get the edge type for the src, dst vertex.
- *
- * @param src source vertex.
- * @param dst destination vertex.
- * @return the appropriate edge type.
- */
- private static CommunicationPatternProperty.Value getEdgeCommunicationPattern(final IRVertex src,
- final IRVertex dst) {
- final Class<?> constructUnionTableFn;
- try {
- constructUnionTableFn = Class.forName("org.apache.beam.sdk.transforms.join.CoGroupByKey$ConstructUnionTableFn");
- } catch (final ClassNotFoundException e) {
- throw new RuntimeException(e);
- }
-
- final Transform srcTransform = src instanceof OperatorVertex ? ((OperatorVertex) src).getTransform() : null;
- final Transform dstTransform = dst instanceof OperatorVertex ? ((OperatorVertex) dst).getTransform() : null;
- final DoFn srcDoFn = srcTransform instanceof DoTransform ? ((DoTransform) srcTransform).getDoFn() : null;
-
- if (srcDoFn != null && srcDoFn.getClass().equals(constructUnionTableFn)) {
- return CommunicationPatternProperty.Value.Shuffle;
- }
- if (srcTransform instanceof FlattenTransform) {
- return CommunicationPatternProperty.Value.OneToOne;
- }
- if (dstTransform instanceof GroupByKeyTransform) {
- return CommunicationPatternProperty.Value.Shuffle;
- }
- if (dstTransform instanceof CreateViewTransform) {
- return CommunicationPatternProperty.Value.BroadCast;
- }
- return CommunicationPatternProperty.Value.OneToOne;
- }
-}
diff --git a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineTranslator.java b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineTranslator.java
new file mode 100644
index 0000000..7e34ca2
--- /dev/null
+++ b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineTranslator.java
@@ -0,0 +1,544 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.compiler.frontend.beam;
+
+import edu.snu.nemo.common.dag.DAG;
+import edu.snu.nemo.common.dag.DAGBuilder;
+import edu.snu.nemo.common.ir.edge.IREdge;
+import edu.snu.nemo.common.ir.edge.executionproperty.*;
+import edu.snu.nemo.common.ir.vertex.IRVertex;
+import edu.snu.nemo.common.ir.vertex.LoopVertex;
+import edu.snu.nemo.common.ir.vertex.OperatorVertex;
+import edu.snu.nemo.common.ir.vertex.transform.Transform;
+import edu.snu.nemo.compiler.frontend.beam.PipelineVisitor.*;
+import edu.snu.nemo.compiler.frontend.beam.coder.BeamDecoderFactory;
+import edu.snu.nemo.compiler.frontend.beam.coder.BeamEncoderFactory;
+import edu.snu.nemo.compiler.frontend.beam.source.BeamBoundedSourceVertex;
+import edu.snu.nemo.compiler.frontend.beam.transform.*;
+import org.apache.beam.sdk.coders.*;
+import org.apache.beam.sdk.io.Read;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.*;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.sdk.values.*;
+
+import java.lang.annotation.*;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Stack;
+import java.util.function.BiFunction;
+
+/**
+ * Converts DAG of Beam pipeline to Nemo IR DAG.
+ * For a {@link PrimitiveTransformVertex}, it defines mapping to the corresponding {@link IRVertex}.
+ * For a {@link CompositeTransformVertex}, it defines how to set up and clear {@link TranslationContext}
+ * before translating the inner Beam transform hierarchy.
+ */
+public final class PipelineTranslator
+ implements BiFunction<CompositeTransformVertex, PipelineOptions, DAG<IRVertex, IREdge>> {
+
+ private static final PipelineTranslator INSTANCE = new PipelineTranslator();
+
+ private final Map<Class<? extends PTransform>, Method> primitiveTransformToTranslator = new HashMap<>();
+ private final Map<Class<? extends PTransform>, Method> compositeTransformToTranslator = new HashMap<>();
+
+ /**
+ * Static entry point: translates the given Beam transform hierarchy by delegating to the shared
+ * singleton translator's {@code apply} method.
+ *
+ * @param pipeline Top-level Beam transform hierarchy, usually given by {@link PipelineVisitor}
+ * @param pipelineOptions {@link PipelineOptions}
+ * @return Nemo IR DAG
+ */
+ public static DAG<IRVertex, IREdge> translate(final CompositeTransformVertex pipeline,
+ final PipelineOptions pipelineOptions) {
+ return INSTANCE.apply(pipeline, pipelineOptions);
+ }
+
+ /**
+ * Creates the translator, while building a map between {@link PTransform}s and the corresponding translators.
+ * Every method declared on this class that carries {@link PrimitiveTransformTranslator} or
+ * {@link CompositeTransformTranslator} is registered under each transform class listed in the annotation.
+ *
+ * @throws RuntimeException if two translator methods are annotated for the same transform class
+ */
+ private PipelineTranslator() {
+ for (final Method translator : getClass().getDeclaredMethods()) {
+ final PrimitiveTransformTranslator primitive = translator.getAnnotation(PrimitiveTransformTranslator.class);
+ final CompositeTransformTranslator composite = translator.getAnnotation(CompositeTransformTranslator.class);
+ if (primitive != null) {
+ for (final Class<? extends PTransform> transform : primitive.value()) {
+ if (primitiveTransformToTranslator.containsKey(transform)) {
+ // The literal must end with a space: the two halves are concatenated, and without it the
+ // message reads "isalready registered".
+ throw new RuntimeException(String.format("Translator for primitive transform %s is "
+ + "already registered: %s", transform, primitiveTransformToTranslator.get(transform)));
+ }
+ primitiveTransformToTranslator.put(transform, translator);
+ }
+ }
+ if (composite != null) {
+ for (final Class<? extends PTransform> transform : composite.value()) {
+ if (compositeTransformToTranslator.containsKey(transform)) {
+ // Same trailing-space requirement as the primitive-transform message above.
+ throw new RuntimeException(String.format("Translator for composite transform %s is "
+ + "already registered: %s", transform, compositeTransformToTranslator.get(transform)));
+ }
+ compositeTransformToTranslator.put(transform, translator);
+ }
+ }
+ }
+ }
+
+ /**
+ * Translates a bounded read into a {@link BeamBoundedSourceVertex} that wraps the transform's source.
+ *
+ * @param ctx provides translation context
+ * @param transformVertex the given PrimitiveTransform to translate
+ * @param transform the bounded read transform being translated
+ */
+ @PrimitiveTransformTranslator(Read.Bounded.class)
+ private static void boundedReadTranslator(final TranslationContext ctx,
+ final PrimitiveTransformVertex transformVertex,
+ final Read.Bounded<?> transform) {
+ final IRVertex vertex = new BeamBoundedSourceVertex<>(transform.getSource());
+ ctx.addVertex(vertex);
+ // NOTE(review): the `false` flag presumably means "not a side input" (side inputs pass `true` in the
+ // ParDo translators) -- confirm against TranslationContext#addEdgeTo.
+ transformVertex.getNode().getInputs().values().forEach(input -> ctx.addEdgeTo(vertex, input, false));
+ // Record the new vertex as the producer of every output of the Beam node.
+ transformVertex.getNode().getOutputs().values().forEach(output -> ctx.registerMainOutputFrom(vertex, output));
+ }
+
+ /**
+ * Translates a single-output ParDo into an {@link OperatorVertex} holding a {@link DoTransform}
+ * built from the transform's DoFn and the context's pipeline options.
+ *
+ * @param ctx provides translation context
+ * @param transformVertex the given PrimitiveTransform to translate
+ * @param transform the single-output ParDo being translated
+ */
+ @PrimitiveTransformTranslator(ParDo.SingleOutput.class)
+ private static void parDoSingleOutputTranslator(final TranslationContext ctx,
+ final PrimitiveTransformVertex transformVertex,
+ final ParDo.SingleOutput<?, ?> transform) {
+ final DoTransform doTransform = new DoTransform(transform.getFn(), ctx.pipelineOptions);
+ final IRVertex vertex = new OperatorVertex(doTransform);
+ ctx.addVertex(vertex);
+ // Connect the node's inputs, skipping those the transform reports as additional inputs;
+ // these edges are added with the flag set to false.
+ transformVertex.getNode().getInputs().values().stream()
+ .filter(input -> !transform.getAdditionalInputs().values().contains(input))
+ .forEach(input -> ctx.addEdgeTo(vertex, input, false));
+ // Side inputs are connected separately, with the flag set to true.
+ transform.getSideInputs().forEach(input -> ctx.addEdgeTo(vertex, input, true));
+ // Every output of the Beam node is registered as a main output of the new vertex.
+ transformVertex.getNode().getOutputs().values().forEach(output -> ctx.registerMainOutputFrom(vertex, output));
+ }
+
+ /**
+ * Translates a multi-output ParDo into an {@link OperatorVertex} holding a {@link DoTransform}.
+ * The output whose tag equals the transform's main output tag is registered as the main output;
+ * every other output is registered as an additional output together with its tag.
+ *
+ * @param ctx provides translation context
+ * @param transformVertex the given PrimitiveTransform to translate
+ * @param transform the multi-output ParDo being translated
+ */
+ @PrimitiveTransformTranslator(ParDo.MultiOutput.class)
+ private static void parDoMultiOutputTranslator(final TranslationContext ctx,
+ final PrimitiveTransformVertex transformVertex,
+ final ParDo.MultiOutput<?, ?> transform) {
+ final DoTransform doTransform = new DoTransform(transform.getFn(), ctx.pipelineOptions);
+ final IRVertex vertex = new OperatorVertex(doTransform);
+ ctx.addVertex(vertex);
+ // Connect the node's inputs, skipping those the transform reports as additional inputs.
+ transformVertex.getNode().getInputs().values().stream()
+ .filter(input -> !transform.getAdditionalInputs().values().contains(input))
+ .forEach(input -> ctx.addEdgeTo(vertex, input, false));
+ // Side inputs are connected separately, with the flag set to true.
+ transform.getSideInputs().forEach(input -> ctx.addEdgeTo(vertex, input, true));
+ // Outputs tagged with the main output tag.
+ transformVertex.getNode().getOutputs().entrySet().stream()
+ .filter(pValueWithTupleTag -> pValueWithTupleTag.getKey().equals(transform.getMainOutputTag()))
+ .forEach(pValueWithTupleTag -> ctx.registerMainOutputFrom(vertex, pValueWithTupleTag.getValue()));
+ // All remaining outputs, registered along with the tag that identifies them.
+ transformVertex.getNode().getOutputs().entrySet().stream()
+ .filter(pValueWithTupleTag -> !pValueWithTupleTag.getKey().equals(transform.getMainOutputTag()))
+ .forEach(pValueWithTupleTag -> ctx.registerAdditionalOutputFrom(vertex, pValueWithTupleTag.getValue(),
+ pValueWithTupleTag.getKey()));
+ }
+
+ /**
+ * Translates a GroupByKey into an {@link OperatorVertex} holding a {@link GroupByKeyTransform}.
+ *
+ * @param ctx provides translation context
+ * @param transformVertex the given PrimitiveTransform to translate
+ * @param transform the GroupByKey being translated (not read by this method)
+ */
+ @PrimitiveTransformTranslator(GroupByKey.class)
+ private static void groupByKeyTranslator(final TranslationContext ctx,
+ final PrimitiveTransformVertex transformVertex,
+ final GroupByKey<?, ?> transform) {
+ final IRVertex vertex = new OperatorVertex(new GroupByKeyTransform());
+ ctx.addVertex(vertex);
+ // Connect every input of the Beam node, then register the vertex as producer of every output.
+ transformVertex.getNode().getInputs().values().forEach(input -> ctx.addEdgeTo(vertex, input, false));
+ transformVertex.getNode().getOutputs().values().forEach(output -> ctx.registerMainOutputFrom(vertex, output));
+ }
+
+ /**
+ * Translates a {@link Window} or {@link Window.Assign} into an {@link OperatorVertex} holding a
+ * {@link WindowTransform} built from the transform's {@link WindowFn}.
+ * The method is registered for two transform classes, so the parameter is typed as a general
+ * {@link PTransform} and narrowed with {@code instanceof}.
+ *
+ * @param ctx provides translation context
+ * @param transformVertex the given PrimitiveTransform to translate
+ * @param transform the windowing transform being translated
+ * @throws UnsupportedOperationException if {@code transform} is neither of the two supported types
+ */
+ @PrimitiveTransformTranslator({Window.class, Window.Assign.class})
+ private static void windowTranslator(final TranslationContext ctx,
+ final PrimitiveTransformVertex transformVertex,
+ final PTransform<?, ?> transform) {
+ final WindowFn windowFn;
+ if (transform instanceof Window) {
+ windowFn = ((Window) transform).getWindowFn();
+ } else if (transform instanceof Window.Assign) {
+ windowFn = ((Window.Assign) transform).getWindowFn();
+ } else {
+ throw new UnsupportedOperationException(String.format("%s is not supported", transform));
+ }
+ final IRVertex vertex = new OperatorVertex(new WindowTransform(windowFn));
+ ctx.addVertex(vertex);
+ // Connect every input of the Beam node, then register the vertex as producer of every output.
+ transformVertex.getNode().getInputs().values().forEach(input -> ctx.addEdgeTo(vertex, input, false));
+ transformVertex.getNode().getOutputs().values().forEach(output -> ctx.registerMainOutputFrom(vertex, output));
+ }
+
+ /**
+ * Translates a CreatePCollectionView into an {@link OperatorVertex} holding a
+ * {@link CreateViewTransform} for the transform's view.
+ *
+ * @param ctx provides translation context
+ * @param transformVertex the given PrimitiveTransform to translate
+ * @param transform the CreatePCollectionView being translated
+ */
+ @PrimitiveTransformTranslator(View.CreatePCollectionView.class)
+ private static void createPCollectionViewTranslator(final TranslationContext ctx,
+ final PrimitiveTransformVertex transformVertex,
+ final View.CreatePCollectionView<?, ?> transform) {
+ final IRVertex vertex = new OperatorVertex(new CreateViewTransform<>(transform.getView()));
+ ctx.addVertex(vertex);
+ transformVertex.getNode().getInputs().values().forEach(input -> ctx.addEdgeTo(vertex, input, false));
+ // The view itself is registered in addition to the node's outputs.
+ // NOTE(review): presumably so that ParDo translators, which pass the views from getSideInputs() to
+ // addEdgeTo, can resolve this vertex as their producer -- confirm against TranslationContext.
+ ctx.registerMainOutputFrom(vertex, transform.getView());
+ transformVertex.getNode().getOutputs().values().forEach(output -> ctx.registerMainOutputFrom(vertex, output));
+ }
+
+ /**
+ * Translates a Flatten over PCollections into an {@link OperatorVertex} holding a {@link FlattenTransform}.
+ *
+ * @param ctx provides translation context
+ * @param transformVertex the given PrimitiveTransform to translate
+ * @param transform the Flatten being translated (not read by this method)
+ */
+ @PrimitiveTransformTranslator(Flatten.PCollections.class)
+ private static void flattenTranslator(final TranslationContext ctx,
+ final PrimitiveTransformVertex transformVertex,
+ final Flatten.PCollections<?> transform) {
+ final IRVertex vertex = new OperatorVertex(new FlattenTransform());
+ ctx.addVertex(vertex);
+ // Connect every input of the Beam node, then register the vertex as producer of every output.
+ transformVertex.getNode().getInputs().values().forEach(input -> ctx.addEdgeTo(vertex, input, false));
+ transformVertex.getNode().getOutputs().values().forEach(output -> ctx.registerMainOutputFrom(vertex, output));
+ }
+
+ /**
+ * Default translator for CompositeTransforms. Translates inner DAG without modifying {@link TranslationContext}.
+ * It is registered for {@link PTransform} itself, the most general transform type.
+ *
+ * @param ctx provides translation context
+ * @param transformVertex the given CompositeTransform to translate
+ * @param transform transform which can be obtained from {@code transformVertex} (not read by this method)
+ */
+ @CompositeTransformTranslator(PTransform.class)
+ private static void topologicalTranslator(final TranslationContext ctx,
+ final CompositeTransformVertex transformVertex,
+ final PTransform<?, ?> transform) {
+ // Hand each inner vertex to the same context, in topological order.
+ transformVertex.getDAG().topologicalDo(ctx::translate);
+ }
+
+ /**
+ * Translator for Combine transform. Implements local combining before shuffling key-value pairs.
+ * When the first inner transform (in topological order) is a GroupByKey, the inner DAG is translated
+ * twice: once under a context that uses {@link OneToOneCommunicationPatternSelector}, and then again
+ * in the given context with a freshly created GroupByKey vertex that consumes the first pass's output.
+ * Otherwise the inner DAG is translated once, in topological order.
+ *
+ * @param ctx provides translation context
+ * @param transformVertex the given CompositeTransform to translate
+ * @param transform transform which can be obtained from {@code transformVertex} (not read by this method)
+ */
+ @CompositeTransformTranslator({Combine.Globally.class, Combine.PerKey.class, Combine.GroupedValues.class})
+ private static void combineTranslator(final TranslationContext ctx,
+ final CompositeTransformVertex transformVertex,
+ final PTransform<?, ?> transform) {
+ final List<TransformVertex> topologicalOrdering = transformVertex.getDAG().getTopologicalSort();
+ // NOTE(review): get(0) assumes the inner DAG has at least one vertex -- confirm this always holds.
+ final TransformVertex first = topologicalOrdering.get(0);
+ final TransformVertex last = topologicalOrdering.get(topologicalOrdering.size() - 1);
+
+ if (first.getNode().getTransform() instanceof GroupByKey) {
+ // Translate the given CompositeTransform under OneToOneEdge-enforced context.
+ final TranslationContext oneToOneEdgeContext = new TranslationContext(ctx,
+ OneToOneCommunicationPatternSelector.INSTANCE);
+ transformVertex.getDAG().topologicalDo(oneToOneEdgeContext::translate);
+
+ // Attempt to translate the CompositeTransform again.
+ // Add GroupByKey, which is the first transform in the given CompositeTransform.
+ // Make sure it consumes the output from the last vertex in OneToOneEdge-translated hierarchy.
+ final IRVertex groupByKey = new OperatorVertex(new GroupByKeyTransform());
+ ctx.addVertex(groupByKey);
+ last.getNode().getOutputs().values().forEach(outputFromCombiner
+ -> ctx.addEdgeTo(groupByKey, outputFromCombiner, false));
+ // Register the new vertex as producer of the original GroupByKey's outputs, so the vertices
+ // translated below take their input from it.
+ first.getNode().getOutputs().values()
+ .forEach(outputFromGroupByKey -> ctx.registerMainOutputFrom(groupByKey, outputFromGroupByKey));
+
+ // Translate the remaining vertices.
+ // skip(1) drops the original GroupByKey, which the vertex created above stands in for.
+ topologicalOrdering.stream().skip(1).forEach(ctx::translate);
+ } else {
+ transformVertex.getDAG().topologicalDo(ctx::translate);
+ }
+ }
+
+ /**
+ * Pushes the loop vertex to the stack before translating the inner DAG, and pops it after the translation.
+ * The loop vertex is named after the full name of the Beam node; the inner DAG is translated by
+ * {@code topologicalTranslator}.
+ *
+ * @param ctx provides translation context
+ * @param transformVertex the given CompositeTransform to translate
+ * @param transform transform which can be obtained from {@code transformVertex}
+ */
+ @CompositeTransformTranslator(LoopCompositeTransform.class)
+ private static void loopTranslator(final TranslationContext ctx,
+ final CompositeTransformVertex transformVertex,
+ final LoopCompositeTransform<?, ?> transform) {
+ final LoopVertex loopVertex = new LoopVertex(transformVertex.getNode().getFullName());
+ // NOTE(review): the vertex is added to the builder and removed again straight away; the purpose of
+ // this pair of calls is not evident from this block -- confirm against DAGBuilder before changing it.
+ ctx.builder.addVertex(loopVertex, ctx.loopVertexStack);
+ ctx.builder.removeVertex(loopVertex);
+ // The loop vertex stays on the stack only while the inner DAG is being translated.
+ ctx.loopVertexStack.push(loopVertex);
+ topologicalTranslator(ctx, transformVertex, transform);
+ ctx.loopVertexStack.pop();
+ }
+
+ /**
+ * Translates the given Beam transform hierarchy into an IR DAG.
+ * A fresh {@link TranslationContext} using the default communication pattern selector performs the translation,
+ * and the DAG is built from that context's builder afterwards.
+ *
+ * @param pipeline the composite vertex that represents the Beam pipeline to translate
+ * @param pipelineOptions the {@link PipelineOptions} handed to the translation context
+ * @return the IR DAG produced by the translation
+ */
+ @Override
+ public DAG<IRVertex, IREdge> apply(final CompositeTransformVertex pipeline, final PipelineOptions pipelineOptions) {
+   final TranslationContext rootContext = new TranslationContext(
+       pipeline,
+       primitiveTransformToTranslator,
+       compositeTransformToTranslator,
+       DefaultCommunicationPatternSelector.INSTANCE,
+       pipelineOptions);
+   rootContext.translate(pipeline);
+   final DAG<IRVertex, IREdge> irDag = rootContext.builder.build();
+   return irDag;
+ }
+
+ /**
+ * Annotates translator for PrimitiveTransform.
+ * The annotation can only be placed on methods and is retained at runtime.
+ * {@code value} holds the {@link PTransform} classes associated with the annotated method.
+ */
+ @Target(ElementType.METHOD)
+ @Retention(RetentionPolicy.RUNTIME)
+ private @interface PrimitiveTransformTranslator {
+ Class<? extends PTransform>[] value();
+ }
+
+ /**
+ * Annotates translator for CompositeTransform.
+ * The annotation can only be placed on methods and is retained at runtime.
+ * {@code value} holds the {@link PTransform} classes associated with the annotated method.
+ */
+ @Target(ElementType.METHOD)
+ @Retention(RetentionPolicy.RUNTIME)
+ private @interface CompositeTransformTranslator {
+ Class<? extends PTransform>[] value();
+ }
+
+ /**
+ * Translation context.
+ */
+ private static final class TranslationContext {
+ private final CompositeTransformVertex pipeline;
+ private final PipelineOptions pipelineOptions;
+ private final DAGBuilder<IRVertex, IREdge> builder;
+ private final Map<PValue, IRVertex> pValueToProducer;
+ private final Map<PValue, TupleTag<?>> pValueToTag;
+ private final Stack<LoopVertex> loopVertexStack;
+ private final BiFunction<IRVertex, IRVertex, CommunicationPatternProperty.Value> communicationPatternSelector;
+
+ private final Map<Class<? extends PTransform>, Method> primitiveTransformToTranslator;
+ private final Map<Class<? extends PTransform>, Method> compositeTransformToTranslator;
+
+ /**
+ * Creates a context that starts from scratch: a new {@link DAGBuilder}, empty PValue lookup tables
+ * and an empty loop vertex stack.
+ *
+ * @param pipeline the pipeline to translate
+ * @param primitiveTransformToTranslator provides translators for PrimitiveTransform
+ * @param compositeTransformToTranslator provides translators for CompositeTransform
+ * @param selector provides {@link CommunicationPatternProperty.Value} for IR edges
+ * @param pipelineOptions {@link PipelineOptions}
+ */
+ private TranslationContext(final CompositeTransformVertex pipeline,
+     final Map<Class<? extends PTransform>, Method> primitiveTransformToTranslator,
+     final Map<Class<? extends PTransform>, Method> compositeTransformToTranslator,
+     final BiFunction<IRVertex, IRVertex, CommunicationPatternProperty.Value> selector,
+     final PipelineOptions pipelineOptions) {
+   // Collaborators handed over by the caller.
+   this.pipeline = pipeline;
+   this.pipelineOptions = pipelineOptions;
+   this.communicationPatternSelector = selector;
+   this.primitiveTransformToTranslator = primitiveTransformToTranslator;
+   this.compositeTransformToTranslator = compositeTransformToTranslator;
+
+   // Translation state created for this context.
+   this.builder = new DAGBuilder<>();
+   this.loopVertexStack = new Stack<>();
+   this.pValueToProducer = new HashMap<>();
+   this.pValueToTag = new HashMap<>();
+ }
+
+ /**
+ * Derives a context from an existing one, replacing only the CommunicationPatternProperty selector.
+ * The builder, the PValue lookup tables and the loop vertex stack are taken over by reference, so vertices,
+ * edges and registrations made through the derived context are visible through {@code ctx} as well.
+ *
+ * @param ctx the original {@link TranslationContext}
+ * @param selector provides {@link CommunicationPatternProperty.Value} for IR edges
+ */
+ private TranslationContext(final TranslationContext ctx,
+     final BiFunction<IRVertex, IRVertex, CommunicationPatternProperty.Value> selector) {
+   // The only member that differs from the original context.
+   this.communicationPatternSelector = selector;
+
+   // Everything else refers to the very same objects as the original context.
+   this.pipeline = ctx.pipeline;
+   this.pipelineOptions = ctx.pipelineOptions;
+   this.primitiveTransformToTranslator = ctx.primitiveTransformToTranslator;
+   this.compositeTransformToTranslator = ctx.compositeTransformToTranslator;
+   this.builder = ctx.builder;
+   this.loopVertexStack = ctx.loopVertexStack;
+   this.pValueToProducer = ctx.pValueToProducer;
+   this.pValueToTag = ctx.pValueToTag;
+ }
+
+ /**
+ * Picks the translator registered for the transform of the given vertex and invokes it.
+ * A vertex without a transform is handled as the root and goes straight to {@code topologicalTranslator}.
+ * Otherwise the translator is looked up by the transform's class, falling back to its superclasses in order.
+ *
+ * @param transformVertex the Beam transform hierarchy to translate
+ * @throws UnsupportedOperationException if neither the transform's class nor any superclass has a translator
+ * @throws RuntimeException if the selected translator cannot be invoked or fails
+ */
+ private void translate(final TransformVertex transformVertex) {
+   final PTransform<?, ?> transform = transformVertex.getNode().getTransform();
+   if (transform == null) {
+     // root node
+     topologicalTranslator(this, (CompositeTransformVertex) transformVertex, null);
+     return;
+   }
+
+   final boolean isComposite = transformVertex instanceof CompositeTransformVertex;
+   final Map<Class<? extends PTransform>, Method> translators =
+       isComposite ? compositeTransformToTranslator : primitiveTransformToTranslator;
+
+   // Walk from the concrete class towards Object until a registered translator shows up.
+   Method translator = null;
+   for (Class<?> clazz = transform.getClass(); clazz != null && translator == null; clazz = clazz.getSuperclass()) {
+     translator = translators.get(clazz);
+   }
+   if (translator == null) {
+     throw new UnsupportedOperationException(String.format("%s transform %s is not supported",
+         isComposite ? "Composite" : "Primitive", transform.getClass().getCanonicalName()));
+   }
+
+   try {
+     translator.setAccessible(true);
+     translator.invoke(null, this, transformVertex, transform);
+   } catch (final IllegalAccessException e) {
+     throw new RuntimeException(e);
+   } catch (final InvocationTargetException | RuntimeException e) {
+     throw new RuntimeException(String.format(
+         "Translator %s have failed to translate %s", translator, transform), e);
+   }
+ }
+
+ /**
+ * Adds an IR vertex to the builder, passing along the current loop vertex stack.
+ *
+ * @param vertex IR vertex to add
+ */
+ private void addVertex(final IRVertex vertex) {
+   this.builder.addVertex(vertex, this.loopVertexStack);
+ }
+
+ /**
+ * Add IR edge to the builder.
+ * The source vertex is the registered producer of {@code input}. The edge gets the communication pattern chosen by
+ * this context's selector, encoder/decoder factories built from the coder of {@code input}, the additional-output
+ * tag if one was registered for {@code input}, and a {@link BeamKeyExtractor}.
+ *
+ * @param dst the destination IR vertex.
+ * @param input the {@link PValue} {@code dst} consumes
+ * @param isSideInput whether it is sideInput or not.
+ * @throws RuntimeException if no producer is registered for {@code input}, if the selector returns {@code null},
+ *         or if no coder can be determined for {@code input}
+ */
+ private void addEdgeTo(final IRVertex dst, final PValue input, final boolean isSideInput) {
+   final IRVertex src = pValueToProducer.get(input);
+   if (src == null) {
+     // Do the lookup on its own inside the try block: throwing the "known to produce" error from within the try
+     // would be caught by the RuntimeException handler below and replaced by the "not found" message.
+     final PrimitiveTransformVertex knownProducer;
+     try {
+       knownProducer = pipeline.getPrimitiveProducerOf(input);
+     } catch (final RuntimeException e) {
+       throw new RuntimeException(String.format("Cannot find a vertex that emits pValue %s, "
+           + "and the corresponding PTransform was not found", input), e);
+     }
+     throw new RuntimeException(String.format("Cannot find a vertex that emits pValue %s, "
+         + "while PTransform %s is known to produce it.", input, knownProducer));
+   }
+
+   final CommunicationPatternProperty.Value communicationPattern = communicationPatternSelector.apply(src, dst);
+   if (communicationPattern == null) {
+     throw new RuntimeException(String.format("%s have failed to determine communication pattern "
+         + "for an edge from %s to %s", communicationPatternSelector, src, dst));
+   }
+   final IREdge edge = new IREdge(communicationPattern, src, dst, isSideInput);
+
+   // Only PCollections and PCollectionViews carry enough information to derive a coder.
+   final Coder<?> coder;
+   if (input instanceof PCollection) {
+     coder = ((PCollection<?>) input).getCoder();
+   } else if (input instanceof PCollectionView) {
+     coder = getCoderForView((PCollectionView<?>) input);
+   } else {
+     coder = null;
+   }
+   if (coder == null) {
+     throw new RuntimeException(String.format("While adding an edge from %s, to %s, coder for PValue %s cannot "
+         + "be determined", src, dst, input));
+   }
+   edge.setProperty(EncoderProperty.of(new BeamEncoderFactory<>(coder)));
+   edge.setProperty(DecoderProperty.of(new BeamDecoderFactory<>(coder)));
+
+   // Outputs registered through registerAdditionalOutputFrom carry a tag that the edge has to expose.
+   if (pValueToTag.containsKey(input)) {
+     edge.setProperty(AdditionalOutputTagProperty.of(pValueToTag.get(input).getId()));
+   }
+   edge.setProperty(KeyExtractorProperty.of(new BeamKeyExtractor()));
+   builder.connectVertices(edge);
+ }
+
+ /**
+ * Records the given {@link IRVertex} as the producer of a {@link PValue} it emits as main output.
+ *
+ * @param irVertex the IR vertex
+ * @param output the {@link PValue} {@code irVertex} emits as main output
+ */
+ private void registerMainOutputFrom(final IRVertex irVertex, final PValue output) {
+   this.pValueToProducer.put(output, irVertex);
+ }
+
+ /**
+ * Records the given {@link IRVertex} as the producer of a {@link PValue} it emits as additional output,
+ * and remembers the tag that belongs to that output.
+ *
+ * @param irVertex the IR vertex
+ * @param output the {@link PValue} {@code irVertex} emits as additional output
+ * @param tag the {@link TupleTag} associated with this additional output
+ */
+ private void registerAdditionalOutputFrom(final IRVertex irVertex, final PValue output, final TupleTag<?> tag) {
+   this.pValueToProducer.put(output, irVertex);
+   this.pValueToTag.put(output, tag);
+ }
+
+ /**
+ * Derives the coder for a {@link PCollectionView} from the coder of the first {@link PCollection} consumed by
+ * the primitive transform that produces the view, wrapped according to the kind of the view's {@link ViewFn}.
+ *
+ * @param view {@link PCollectionView} from the corresponding {@link View.CreatePCollectionView} transform
+ * @return appropriate {@link Coder} for {@link PCollectionView}
+ * @throws RuntimeException if the producing transform consumes no {@link PCollection}
+ * @throws UnsupportedOperationException if the kind of {@link ViewFn} is not handled
+ */
+ private Coder<?> getCoderForView(final PCollectionView view) {
+   final PrimitiveTransformVertex producer = pipeline.getPrimitiveProducerOf(view);
+   final Coder<?> elementCoder = producer.getNode().getInputs().values().stream()
+       .filter(PCollection.class::isInstance)
+       .map(PCollection.class::cast)
+       .findFirst()
+       .orElseThrow(() -> new RuntimeException(String.format("No incoming PCollection to %s", producer)))
+       .getCoder();
+
+   final ViewFn viewFn = view.getViewFn();
+   if (viewFn instanceof PCollectionViews.IterableViewFn) {
+     return IterableCoder.of(elementCoder);
+   }
+   if (viewFn instanceof PCollectionViews.ListViewFn) {
+     return ListCoder.of(elementCoder);
+   }
+   if (viewFn instanceof PCollectionViews.MapViewFn) {
+     final KvCoder<?, ?> kvCoder = (KvCoder) elementCoder;
+     return MapCoder.of(kvCoder.getKeyCoder(), kvCoder.getValueCoder());
+   }
+   if (viewFn instanceof PCollectionViews.MultimapViewFn) {
+     final KvCoder<?, ?> kvCoder = (KvCoder) elementCoder;
+     return MapCoder.of(kvCoder.getKeyCoder(), IterableCoder.of(kvCoder.getValueCoder()));
+   }
+   if (viewFn instanceof PCollectionViews.SingletonViewFn) {
+     return elementCoder;
+   }
+   throw new UnsupportedOperationException(String.format("Unsupported viewFn %s", viewFn.getClass()));
+ }
+ }
+
+ /**
+ * Default {@link CommunicationPatternProperty.Value} selector.
+ * Rules are checked in this order: a source running CoGroupByKey's ConstructUnionTableFn yields Shuffle,
+ * a Flatten source yields OneToOne, a GroupByKey destination yields Shuffle, a CreateView destination yields
+ * BroadCast, and everything else yields OneToOne.
+ */
+ private static final class DefaultCommunicationPatternSelector
+     implements BiFunction<IRVertex, IRVertex, CommunicationPatternProperty.Value> {
+
+   private static final DefaultCommunicationPatternSelector INSTANCE = new DefaultCommunicationPatternSelector();
+
+   /**
+    * @param src the source vertex of the edge
+    * @param dst the destination vertex of the edge
+    * @return the communication pattern for the edge from {@code src} to {@code dst}
+    * @throws RuntimeException if the ConstructUnionTableFn class cannot be loaded
+    */
+   @Override
+   public CommunicationPatternProperty.Value apply(final IRVertex src, final IRVertex dst) {
+     // The class is resolved by name; it is not referenced directly from this code.
+     final Class<?> unionTableFnClass;
+     try {
+       unionTableFnClass = Class.forName("org.apache.beam.sdk.transforms.join.CoGroupByKey$ConstructUnionTableFn");
+     } catch (final ClassNotFoundException e) {
+       throw new RuntimeException(e);
+     }
+
+     // Only OperatorVertex exposes a transform; other vertices are treated as having none.
+     Transform upstream = null;
+     if (src instanceof OperatorVertex) {
+       upstream = ((OperatorVertex) src).getTransform();
+     }
+     Transform downstream = null;
+     if (dst instanceof OperatorVertex) {
+       downstream = ((OperatorVertex) dst).getTransform();
+     }
+
+     if (upstream instanceof DoTransform) {
+       final DoFn upstreamDoFn = ((DoTransform) upstream).getDoFn();
+       if (upstreamDoFn != null && upstreamDoFn.getClass().equals(unionTableFnClass)) {
+         return CommunicationPatternProperty.Value.Shuffle;
+       }
+     }
+     if (upstream instanceof FlattenTransform) {
+       return CommunicationPatternProperty.Value.OneToOne;
+     }
+     if (downstream instanceof GroupByKeyTransform) {
+       return CommunicationPatternProperty.Value.Shuffle;
+     }
+     if (downstream instanceof CreateViewTransform) {
+       return CommunicationPatternProperty.Value.BroadCast;
+     }
+     return CommunicationPatternProperty.Value.OneToOne;
+   }
+ }
+
+ /**
+ * A {@link CommunicationPatternProperty.Value} selector which always emits OneToOne.
+ * It is passed to a {@link TranslationContext} to enforce OneToOne edges regardless of the vertices involved.
+ */
+ private static final class OneToOneCommunicationPatternSelector
+ implements BiFunction<IRVertex, IRVertex, CommunicationPatternProperty.Value> {
+ // Shared instance; the class declares no fields other than this one.
+ private static final OneToOneCommunicationPatternSelector INSTANCE = new OneToOneCommunicationPatternSelector();
+ /**
+ * @param src the source vertex of the edge (not inspected)
+ * @param dst the destination vertex of the edge (not inspected)
+ * @return always {@link CommunicationPatternProperty.Value#OneToOne}
+ */
+ @Override
+ public CommunicationPatternProperty.Value apply(final IRVertex src, final IRVertex dst) {
+ return CommunicationPatternProperty.Value.OneToOne;
+ }
+ }
+}
diff --git a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineVisitor.java b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineVisitor.java
new file mode 100644
index 0000000..55be865
--- /dev/null
+++ b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineVisitor.java
@@ -0,0 +1,296 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.compiler.frontend.beam;
+
+import edu.snu.nemo.common.dag.DAG;
+import edu.snu.nemo.common.dag.DAGBuilder;
+import edu.snu.nemo.common.dag.Edge;
+import edu.snu.nemo.common.dag.Vertex;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.values.PValue;
+
+import java.util.*;
+
+/**
+ * Traverses through the given Beam pipeline to construct a DAG of Beam Transform,
+ * while preserving hierarchy of CompositeTransforms.
+ * Hierarchy is established when a CompositeTransform is expanded to other CompositeTransforms or PrimitiveTransforms,
+ * as the former CompositeTransform becoming 'enclosingVertex' which have the inner transforms as embedded DAG.
+ * This DAG will be later translated by {@link PipelineTranslator} into Nemo IR DAG.
+ */
+public final class PipelineVisitor extends Pipeline.PipelineVisitor.Defaults {
+
+ private static final String TRANSFORM = "Transform-";
+ private static final String DATAFLOW = "Dataflow-";
+
+ private final Stack<CompositeTransformVertex> compositeTransformVertexStack = new Stack<>();
+ private CompositeTransformVertex rootVertex = null;
+ private int nextIdx = 0;
+
+ /**
+ * Wraps the visited primitive node into a vertex, adds it to the composite currently on top of the stack,
+ * and records one data flow edge per {@link PValue} the new vertex consumes.
+ *
+ * @param node the primitive Beam node being visited
+ */
+ @Override
+ public void visitPrimitiveTransform(final TransformHierarchy.Node node) {
+   final CompositeTransformVertex enclosing = compositeTransformVertexStack.peek();
+   final PrimitiveTransformVertex vertex = new PrimitiveTransformVertex(node, enclosing);
+   enclosing.addVertex(vertex);
+
+   for (final PValue pValue : vertex.getPValuesConsumed()) {
+     // The edge lives in the composite that contains both the producer and the chosen destination.
+     final TransformVertex dst = getDestinationOfDataFlowEdge(vertex, pValue);
+     final CompositeTransformVertex scope = dst.enclosingVertex;
+     scope.addDataFlow(new DataFlowEdge(scope.getProducerOf(pValue), dst));
+   }
+ }
+
+ /**
+ * Creates a vertex for the composite node being entered and pushes it onto the stack.
+ * The vertex on top of the stack becomes its enclosing vertex; with an empty stack there is none.
+ *
+ * @param node the composite Beam node being entered
+ * @return always {@link CompositeBehavior#ENTER_TRANSFORM}
+ */
+ @Override
+ public CompositeBehavior enterCompositeTransform(final TransformHierarchy.Node node) {
+   // There is always a top-level CompositeTransform that encompasses the entire Beam pipeline;
+   // it is the one entered while the stack is still empty.
+   final CompositeTransformVertex parent =
+       compositeTransformVertexStack.isEmpty() ? null : compositeTransformVertexStack.peek();
+   compositeTransformVertexStack.push(new CompositeTransformVertex(node, parent));
+   return CompositeBehavior.ENTER_TRANSFORM;
+ }
+
+ /**
+ * Pops the composite vertex being left and builds its inner DAG.
+ * The vertex is then added to the composite below it on the stack, or kept as the root if the stack is empty.
+ *
+ * @param node the composite Beam node being left
+ * @throws RuntimeException if a root vertex has already been recorded by this visitor
+ */
+ @Override
+ public void leaveCompositeTransform(final TransformHierarchy.Node node) {
+   final CompositeTransformVertex finished = compositeTransformVertexStack.pop();
+   finished.build();
+
+   if (!compositeTransformVertexStack.isEmpty()) {
+     // The CompositeTransformVertex is ready; adding it to its enclosing vertex.
+     compositeTransformVertexStack.peek().addVertex(finished);
+     return;
+   }
+
+   // Nothing is left on the stack, so the vertex is the root.
+   if (rootVertex != null) {
+     throw new RuntimeException("The visitor already have traversed a Beam pipeline. "
+         + "Re-using a visitor is not allowed.");
+   }
+   rootVertex = finished;
+ }
+
+ /**
+ * Returns the result of the traversal.
+ *
+ * @return A vertex representing the top-level CompositeTransform.
+ * @throws RuntimeException if no root vertex has been recorded yet
+ */
+ public CompositeTransformVertex getConvertedPipeline() {
+   if (rootVertex != null) {
+     return rootVertex;
+   }
+   throw new RuntimeException("The visitor have not fully traversed through a Beam pipeline.");
+ }
+
+ /**
+ * Represents a {@link org.apache.beam.sdk.transforms.PTransform} as a vertex in DAG.
+ * This is a non-static inner class: vertex ids are drawn from the enclosing visitor's {@code nextIdx} counter.
+ */
+ public abstract class TransformVertex extends Vertex {
+ private final TransformHierarchy.Node node;
+ private final CompositeTransformVertex enclosingVertex;
+
+ /**
+ * @param node the corresponding Beam node
+ * @param enclosingVertex the vertex for the transform which inserted this transform as its expansion,
+ * or {@code null}
+ */
+ private TransformVertex(final TransformHierarchy.Node node, final CompositeTransformVertex enclosingVertex) {
+ // The id is the TRANSFORM prefix followed by the current value of nextIdx, which is incremented afterwards.
+ super(String.format("%s%d", TRANSFORM, nextIdx++));
+ this.node = node;
+ this.enclosingVertex = enclosingVertex;
+ }
+
+ /**
+ * @return Collection of {@link PValue}s this transform emits.
+ */
+ public abstract Collection<PValue> getPValuesProduced();
+
+ /**
+ * Searches within {@code this} to find a transform that produces the given {@link PValue}.
+ *
+ * @param pValue a {@link PValue}
+ * @return the {@link TransformVertex} whose {@link org.apache.beam.sdk.transforms.PTransform}
+ * produces the given {@code pValue}
+ */
+ public abstract PrimitiveTransformVertex getPrimitiveProducerOf(final PValue pValue);
+
+ /**
+ * @return the corresponding Beam node.
+ */
+ public TransformHierarchy.Node getNode() {
+ return node;
+ }
+
+ /**
+ * @return the enclosing {@link CompositeTransformVertex} if any, {@code null} otherwise.
+ */
+ public CompositeTransformVertex getEnclosingVertex() {
+ return enclosingVertex;
+ }
+ }
+
+ /**
+ * Represents a transform hierarchy for primitive transform.
+ * The produced and consumed {@link PValue}s are collected once at construction time. Besides the inputs and
+ * outputs of the Beam node, the view of a {@link View.CreatePCollectionView} counts as produced, and the side
+ * inputs of a {@link ParDo.SingleOutput} or {@link ParDo.MultiOutput} count as consumed.
+ */
+ public final class PrimitiveTransformVertex extends TransformVertex {
+   private final List<PValue> pValuesProduced = new ArrayList<>();
+   private final List<PValue> pValuesConsumed = new ArrayList<>();
+
+   /**
+    * @param node the corresponding Beam node
+    * @param enclosingVertex the composite vertex this primitive transform belongs to
+    */
+   private PrimitiveTransformVertex(final TransformHierarchy.Node node,
+       final CompositeTransformVertex enclosingVertex) {
+     super(node, enclosingVertex);
+     // Wildcard casts (instead of raw types) keep the addAll calls below free of unchecked conversions.
+     if (node.getTransform() instanceof View.CreatePCollectionView) {
+       pValuesProduced.add(((View.CreatePCollectionView<?, ?>) node.getTransform()).getView());
+     }
+     if (node.getTransform() instanceof ParDo.SingleOutput) {
+       pValuesConsumed.addAll(((ParDo.SingleOutput<?, ?>) node.getTransform()).getSideInputs());
+     }
+     if (node.getTransform() instanceof ParDo.MultiOutput) {
+       pValuesConsumed.addAll(((ParDo.MultiOutput<?, ?>) node.getTransform()).getSideInputs());
+     }
+     pValuesProduced.addAll(getNode().getOutputs().values());
+     pValuesConsumed.addAll(getNode().getInputs().values());
+   }
+
+   @Override
+   public Collection<PValue> getPValuesProduced() {
+     return pValuesProduced;
+   }
+
+   /**
+    * @param pValue a {@link PValue}
+    * @return {@code this}, if this transform produces {@code pValue}
+    * @throws RuntimeException if this transform does not produce {@code pValue}
+    */
+   @Override
+   public PrimitiveTransformVertex getPrimitiveProducerOf(final PValue pValue) {
+     if (!getPValuesProduced().contains(pValue)) {
+       // Name both sides of the mismatch; a message-less exception is hard to trace back to its origin.
+       throw new RuntimeException(String.format("Primitive transform %s does not produce %s",
+           getNode().getFullName(), pValue));
+     }
+     return this;
+   }
+
+   /**
+    * @return collection of {@link PValue} this transform consumes.
+    */
+   public Collection<PValue> getPValuesConsumed() {
+     return pValuesConsumed;
+   }
+ }
+ /**
+ * Represents a transform hierarchy for composite transform.
+ * Child vertices and data flow edges are collected while the composite is being visited. {@code build()} then
+ * connects the collected edges and freezes the inner DAG, which is {@code null} until that point.
+ */
+ public final class CompositeTransformVertex extends TransformVertex {
+   private final Map<PValue, TransformVertex> pValueToProducer = new HashMap<>();
+   private final Collection<DataFlowEdge> dataFlowEdges = new ArrayList<>();
+   private final DAGBuilder<TransformVertex, DataFlowEdge> builder = new DAGBuilder<>();
+   private DAG<TransformVertex, DataFlowEdge> dag = null;
+
+   /**
+    * @param node the corresponding Beam node
+    * @param enclosingVertex the composite vertex enclosing this one, or {@code null}
+    */
+   private CompositeTransformVertex(final TransformHierarchy.Node node,
+       final CompositeTransformVertex enclosingVertex) {
+     super(node, enclosingVertex);
+   }
+
+   /**
+    * Finalize this vertex and make it ready to be added to another {@link CompositeTransformVertex}.
+    *
+    * @throws RuntimeException if the DAG has been built before
+    */
+   private void build() {
+     if (dag != null) {
+       throw new RuntimeException("DAG already have been built.");
+     }
+     // Edges are connected only now, after all child vertices have been added to the builder.
+     dataFlowEdges.forEach(builder::connectVertices);
+     dag = builder.build();
+   }
+
+   /**
+    * Add a {@link TransformVertex}.
+    * Every {@link PValue} the vertex produces is registered as being produced by that vertex.
+    *
+    * @param vertex the vertex to add
+    */
+   private void addVertex(final TransformVertex vertex) {
+     vertex.getPValuesProduced().forEach(value -> pValueToProducer.put(value, vertex));
+     builder.addVertex(vertex);
+   }
+
+   /**
+    * Add a {@link DataFlowEdge}.
+    *
+    * @param dataFlowEdge the edge to add
+    */
+   private void addDataFlow(final DataFlowEdge dataFlowEdge) {
+     dataFlowEdges.add(dataFlowEdge);
+   }
+
+   /**
+    * @return the {@link PValue}s produced by the children added so far, as a view of the internal map's key set
+    */
+   @Override
+   public Collection<PValue> getPValuesProduced() {
+     return pValueToProducer.keySet();
+   }
+
+   /**
+    * Get a direct child of this vertex which produces the given {@link PValue}.
+    *
+    * @param pValue the {@link PValue} to search
+    * @return the direct child of this vertex which produces {@code pValue}
+    * @throws RuntimeException if no direct child produces {@code pValue}
+    */
+   public TransformVertex getProducerOf(final PValue pValue) {
+     final TransformVertex vertex = pValueToProducer.get(pValue);
+     if (vertex == null) {
+       // Name the composite and the PValue; a message-less exception is hard to trace back to its origin.
+       throw new RuntimeException(String.format("Composite transform %s has no direct child which produces %s",
+           getNode().getFullName(), pValue));
+     }
+     return vertex;
+   }
+
+   @Override
+   public PrimitiveTransformVertex getPrimitiveProducerOf(final PValue pValue) {
+     return getProducerOf(pValue).getPrimitiveProducerOf(pValue);
+   }
+
+   /**
+    * @return DAG of Beam hierarchy
+    */
+   public DAG<TransformVertex, DataFlowEdge> getDAG() {
+     return dag;
+   }
+ }
+
+ /**
+ * Represents data flow from a transform to another transform.
+ * Edge ids are drawn from the enclosing visitor's {@code nextIdx} counter.
+ */
+ public final class DataFlowEdge extends Edge<TransformVertex> {
+ /**
+ * @param src source vertex
+ * @param dst destination vertex
+ */
+ private DataFlowEdge(final TransformVertex src, final TransformVertex dst) {
+ // The id is the DATAFLOW prefix followed by the current value of nextIdx, which is incremented afterwards.
+ super(String.format("%s%d", DATAFLOW, nextIdx++), src, dst);
+ }
+ }
+
+ /**
+ * Climbs the enclosing vertices of a consumer until reaching one whose enclosing composite produces the
+ * given {@link PValue}.
+ *
+ * @param primitiveConsumer a {@link PrimitiveTransformVertex} which consumes {@code pValue}
+ * @param pValue the specified {@link PValue}
+ * @return the closest {@link TransformVertex} to {@code primitiveConsumer},
+ *         which is equal to or encloses {@code primitiveConsumer} and can be the destination vertex of
+ *         data flow edge from the producer of {@code pValue} to {@code primitiveConsumer}.
+ * @throws RuntimeException if the climb reaches a vertex without an enclosing vertex before finding a producer
+ */
+ private TransformVertex getDestinationOfDataFlowEdge(final PrimitiveTransformVertex primitiveConsumer,
+     final PValue pValue) {
+   TransformVertex candidate = primitiveConsumer;
+   CompositeTransformVertex scope = candidate.getEnclosingVertex();
+   while (!scope.getPValuesProduced().contains(pValue)) {
+     // Move one level up: the composite just inspected becomes the next candidate.
+     candidate = scope;
+     scope = candidate.getEnclosingVertex();
+     if (scope == null) {
+       throw new RuntimeException(String.format("Cannot find producer of %s", pValue));
+     }
+   }
+   return candidate;
+ }
+}
diff --git a/compiler/test/src/test/java/edu/snu/nemo/compiler/frontend/beam/BeamFrontendALSTest.java b/compiler/test/src/test/java/edu/snu/nemo/compiler/frontend/beam/BeamFrontendALSTest.java
index 67f146a..12c8244 100644
--- a/compiler/test/src/test/java/edu/snu/nemo/compiler/frontend/beam/BeamFrontendALSTest.java
+++ b/compiler/test/src/test/java/edu/snu/nemo/compiler/frontend/beam/BeamFrontendALSTest.java
@@ -38,22 +38,22 @@ public final class BeamFrontendALSTest {
final DAG<IRVertex, IREdge> producedDAG = CompilerTestUtil.compileALSDAG();
assertEquals(producedDAG.getTopologicalSort(), producedDAG.getTopologicalSort());
- assertEquals(38, producedDAG.getVertices().size());
+ assertEquals(42, producedDAG.getVertices().size());
// producedDAG.getTopologicalSort().forEach(v -> System.out.println(v.getId()));
- final IRVertex vertex4 = producedDAG.getTopologicalSort().get(6);
- assertEquals(1, producedDAG.getIncomingEdgesOf(vertex4).size());
- assertEquals(1, producedDAG.getIncomingEdgesOf(vertex4.getId()).size());
- assertEquals(4, producedDAG.getOutgoingEdgesOf(vertex4).size());
+ final IRVertex vertex11 = producedDAG.getTopologicalSort().get(5);
+ assertEquals(1, producedDAG.getIncomingEdgesOf(vertex11).size());
+ assertEquals(1, producedDAG.getIncomingEdgesOf(vertex11.getId()).size());
+ assertEquals(4, producedDAG.getOutgoingEdgesOf(vertex11).size());
- final IRVertex vertex13 = producedDAG.getTopologicalSort().get(11);
- assertEquals(1, producedDAG.getIncomingEdgesOf(vertex13).size());
- assertEquals(1, producedDAG.getIncomingEdgesOf(vertex13.getId()).size());
- assertEquals(1, producedDAG.getOutgoingEdgesOf(vertex13).size());
+ final IRVertex vertex17 = producedDAG.getTopologicalSort().get(10);
+ assertEquals(1, producedDAG.getIncomingEdgesOf(vertex17).size());
+ assertEquals(1, producedDAG.getIncomingEdgesOf(vertex17.getId()).size());
+ assertEquals(1, producedDAG.getOutgoingEdgesOf(vertex17).size());
- final IRVertex vertex14 = producedDAG.getTopologicalSort().get(12);
- assertEquals(2, producedDAG.getIncomingEdgesOf(vertex14).size());
- assertEquals(2, producedDAG.getIncomingEdgesOf(vertex14.getId()).size());
- assertEquals(1, producedDAG.getOutgoingEdgesOf(vertex14).size());
+ final IRVertex vertex18 = producedDAG.getTopologicalSort().get(16);
+ assertEquals(2, producedDAG.getIncomingEdgesOf(vertex18).size());
+ assertEquals(2, producedDAG.getIncomingEdgesOf(vertex18.getId()).size());
+ assertEquals(1, producedDAG.getOutgoingEdgesOf(vertex18).size());
}
}
diff --git a/compiler/test/src/test/java/edu/snu/nemo/compiler/frontend/beam/BeamFrontendMLRTest.java b/compiler/test/src/test/java/edu/snu/nemo/compiler/frontend/beam/BeamFrontendMLRTest.java
index 0cb3a26..7f4d591 100644
--- a/compiler/test/src/test/java/edu/snu/nemo/compiler/frontend/beam/BeamFrontendMLRTest.java
+++ b/compiler/test/src/test/java/edu/snu/nemo/compiler/frontend/beam/BeamFrontendMLRTest.java
@@ -38,21 +38,21 @@ public class BeamFrontendMLRTest {
final DAG<IRVertex, IREdge> producedDAG = CompilerTestUtil.compileMLRDAG();
assertEquals(producedDAG.getTopologicalSort(), producedDAG.getTopologicalSort());
- assertEquals(36, producedDAG.getVertices().size());
+ assertEquals(42, producedDAG.getVertices().size());
- final IRVertex vertex3 = producedDAG.getTopologicalSort().get(0);
- assertEquals(0, producedDAG.getIncomingEdgesOf(vertex3).size());
- assertEquals(0, producedDAG.getIncomingEdgesOf(vertex3.getId()).size());
- assertEquals(3, producedDAG.getOutgoingEdgesOf(vertex3).size());
+ final IRVertex vertex1 = producedDAG.getTopologicalSort().get(5);
+ assertEquals(0, producedDAG.getIncomingEdgesOf(vertex1).size());
+ assertEquals(0, producedDAG.getIncomingEdgesOf(vertex1.getId()).size());
+ assertEquals(3, producedDAG.getOutgoingEdgesOf(vertex1).size());
- final IRVertex vertex13 = producedDAG.getTopologicalSort().get(11);
- assertEquals(1, producedDAG.getIncomingEdgesOf(vertex13).size());
- assertEquals(1, producedDAG.getIncomingEdgesOf(vertex13.getId()).size());
- assertEquals(1, producedDAG.getOutgoingEdgesOf(vertex13).size());
+ final IRVertex vertex15 = producedDAG.getTopologicalSort().get(13);
+ assertEquals(1, producedDAG.getIncomingEdgesOf(vertex15).size());
+ assertEquals(1, producedDAG.getIncomingEdgesOf(vertex15.getId()).size());
+ assertEquals(1, producedDAG.getOutgoingEdgesOf(vertex15).size());
- final IRVertex vertex19 = producedDAG.getTopologicalSort().get(17);
- assertEquals(2, producedDAG.getIncomingEdgesOf(vertex19).size());
- assertEquals(2, producedDAG.getIncomingEdgesOf(vertex19.getId()).size());
- assertEquals(1, producedDAG.getOutgoingEdgesOf(vertex19).size());
+ final IRVertex vertex21 = producedDAG.getTopologicalSort().get(19);
+ assertEquals(2, producedDAG.getIncomingEdgesOf(vertex21).size());
+ assertEquals(2, producedDAG.getIncomingEdgesOf(vertex21.getId()).size());
+ assertEquals(1, producedDAG.getOutgoingEdgesOf(vertex21).size());
}
}
diff --git a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/TransientResourceCompositePassTest.java b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/TransientResourceCompositePassTest.java
index e035241..97a5cd2 100644
--- a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/TransientResourceCompositePassTest.java
+++ b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/TransientResourceCompositePassTest.java
@@ -53,37 +53,37 @@ public class TransientResourceCompositePassTest {
final IRVertex vertex1 = processedDAG.getTopologicalSort().get(0);
assertEquals(ResourcePriorityProperty.TRANSIENT, vertex1.getPropertyValue(ResourcePriorityProperty.class).get());
- final IRVertex vertex5 = processedDAG.getTopologicalSort().get(1);
- assertEquals(ResourcePriorityProperty.TRANSIENT, vertex5.getPropertyValue(ResourcePriorityProperty.class).get());
- processedDAG.getIncomingEdgesOf(vertex5).forEach(irEdge -> {
+ final IRVertex vertex2 = processedDAG.getTopologicalSort().get(11);
+ assertEquals(ResourcePriorityProperty.TRANSIENT, vertex2.getPropertyValue(ResourcePriorityProperty.class).get());
+ processedDAG.getIncomingEdgesOf(vertex2).forEach(irEdge -> {
assertEquals(DataStoreProperty.Value.MemoryStore, irEdge.getPropertyValue(DataStoreProperty.class).get());
assertEquals(DataFlowProperty.Value.Pull, irEdge.getPropertyValue(DataFlowProperty.class).get());
});
- final IRVertex vertex6 = processedDAG.getTopologicalSort().get(2);
- assertEquals(ResourcePriorityProperty.RESERVED, vertex6.getPropertyValue(ResourcePriorityProperty.class).get());
- processedDAG.getIncomingEdgesOf(vertex6).forEach(irEdge -> {
+ final IRVertex vertex5 = processedDAG.getTopologicalSort().get(14);
+ assertEquals(ResourcePriorityProperty.RESERVED, vertex5.getPropertyValue(ResourcePriorityProperty.class).get());
+ processedDAG.getIncomingEdgesOf(vertex5).forEach(irEdge -> {
assertEquals(DataStoreProperty.Value.LocalFileStore, irEdge.getPropertyValue(DataStoreProperty.class).get());
assertEquals(DataFlowProperty.Value.Push, irEdge.getPropertyValue(DataFlowProperty.class).get());
});
- final IRVertex vertex4 = processedDAG.getTopologicalSort().get(6);
- assertEquals(ResourcePriorityProperty.RESERVED, vertex4.getPropertyValue(ResourcePriorityProperty.class).get());
- processedDAG.getIncomingEdgesOf(vertex4).forEach(irEdge -> {
+ final IRVertex vertex11 = processedDAG.getTopologicalSort().get(5);
+ assertEquals(ResourcePriorityProperty.RESERVED, vertex11.getPropertyValue(ResourcePriorityProperty.class).get());
+ processedDAG.getIncomingEdgesOf(vertex11).forEach(irEdge -> {
assertEquals(DataStoreProperty.Value.MemoryStore, irEdge.getPropertyValue(DataStoreProperty.class).get());
assertEquals(DataFlowProperty.Value.Pull, irEdge.getPropertyValue(DataFlowProperty.class).get());
});
- final IRVertex vertex13 = processedDAG.getTopologicalSort().get(11);
- assertEquals(ResourcePriorityProperty.TRANSIENT, vertex13.getPropertyValue(ResourcePriorityProperty.class).get());
- processedDAG.getIncomingEdgesOf(vertex13).forEach(irEdge -> {
+ final IRVertex vertex17 = processedDAG.getTopologicalSort().get(10);
+ assertEquals(ResourcePriorityProperty.TRANSIENT, vertex17.getPropertyValue(ResourcePriorityProperty.class).get());
+ processedDAG.getIncomingEdgesOf(vertex17).forEach(irEdge -> {
assertEquals(DataStoreProperty.Value.LocalFileStore, irEdge.getPropertyValue(DataStoreProperty.class).get());
assertEquals(DataFlowProperty.Value.Pull, irEdge.getPropertyValue(DataFlowProperty.class).get());
});
- final IRVertex vertex15 = processedDAG.getTopologicalSort().get(13);
- assertEquals(ResourcePriorityProperty.RESERVED, vertex15.getPropertyValue(ResourcePriorityProperty.class).get());
- processedDAG.getIncomingEdgesOf(vertex15).forEach(irEdge -> {
+ final IRVertex vertex19 = processedDAG.getTopologicalSort().get(17);
+ assertEquals(ResourcePriorityProperty.RESERVED, vertex19.getPropertyValue(ResourcePriorityProperty.class).get());
+ processedDAG.getIncomingEdgesOf(vertex19).forEach(irEdge -> {
assertEquals(DataStoreProperty.Value.LocalFileStore, irEdge.getPropertyValue(DataStoreProperty.class).get());
assertEquals(DataFlowProperty.Value.Push, irEdge.getPropertyValue(DataFlowProperty.class).get());
});
diff --git a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopExtractionPassTest.java b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopExtractionPassTest.java
index acccd73..7e0c425 100644
--- a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopExtractionPassTest.java
+++ b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopExtractionPassTest.java
@@ -45,6 +45,6 @@ public class LoopExtractionPassTest {
public void testLoopGrouping() {
final DAG<IRVertex, IREdge> processedDAG = new LoopExtractionPass().apply(compiledDAG);
- assertEquals(9, processedDAG.getTopologicalSort().size());
+ assertEquals(13, processedDAG.getTopologicalSort().size());
}
}
diff --git a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopInvariantCodeMotionALSInefficientTest.java b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopInvariantCodeMotionALSInefficientTest.java
index 64d5b74..2162463 100644
--- a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopInvariantCodeMotionALSInefficientTest.java
+++ b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopInvariantCodeMotionALSInefficientTest.java
@@ -45,7 +45,7 @@ public class LoopInvariantCodeMotionALSInefficientTest {
@Test
public void testForInefficientALSDAG() throws Exception {
- final long expectedNumOfVertices = groupedDAG.getVertices().size() + 3;
+ final long expectedNumOfVertices = groupedDAG.getVertices().size() + 5;
final DAG<IRVertex, IREdge> processedDAG = LoopOptimizations.getLoopInvariantCodeMotionPass()
.apply(groupedDAG);
diff --git a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopInvariantCodeMotionPassTest.java b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopInvariantCodeMotionPassTest.java
index 07210ee..5cd7928 100644
--- a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopInvariantCodeMotionPassTest.java
+++ b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopInvariantCodeMotionPassTest.java
@@ -59,32 +59,32 @@ public class LoopInvariantCodeMotionPassTest {
assertTrue(alsLoopOpt.isPresent());
final LoopVertex alsLoop = alsLoopOpt.get();
- final IRVertex vertex7 = groupedDAG.getTopologicalSort().get(3);
- final IRVertex vertex14 = alsLoop.getDAG().getTopologicalSort().get(4);
+ final IRVertex vertex6 = groupedDAG.getTopologicalSort().get(11);
+ final IRVertex vertex18 = alsLoop.getDAG().getTopologicalSort().get(4);
- final Set<IREdge> oldDAGIncomingEdges = alsLoop.getDagIncomingEdges().get(vertex14);
- final List<IREdge> newDAGIncomingEdge = groupedDAG.getIncomingEdgesOf(vertex7);
+ final Set<IREdge> oldDAGIncomingEdges = alsLoop.getDagIncomingEdges().get(vertex18);
+ final List<IREdge> newDAGIncomingEdge = groupedDAG.getIncomingEdgesOf(vertex6);
- alsLoop.getDagIncomingEdges().remove(vertex14);
- alsLoop.getDagIncomingEdges().putIfAbsent(vertex7, new HashSet<>());
- newDAGIncomingEdge.forEach(alsLoop.getDagIncomingEdges().get(vertex7)::add);
+ alsLoop.getDagIncomingEdges().remove(vertex18);
+ alsLoop.getDagIncomingEdges().putIfAbsent(vertex6, new HashSet<>());
+ newDAGIncomingEdge.forEach(alsLoop.getDagIncomingEdges().get(vertex6)::add);
- alsLoop.getNonIterativeIncomingEdges().remove(vertex14);
- alsLoop.getNonIterativeIncomingEdges().putIfAbsent(vertex7, new HashSet<>());
- newDAGIncomingEdge.forEach(alsLoop.getNonIterativeIncomingEdges().get(vertex7)::add);
+ alsLoop.getNonIterativeIncomingEdges().remove(vertex18);
+ alsLoop.getNonIterativeIncomingEdges().putIfAbsent(vertex6, new HashSet<>());
+ newDAGIncomingEdge.forEach(alsLoop.getNonIterativeIncomingEdges().get(vertex6)::add);
- alsLoop.getBuilder().addVertex(vertex7);
+ alsLoop.getBuilder().addVertex(vertex6);
oldDAGIncomingEdges.forEach(alsLoop.getBuilder()::connectVertices);
final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
groupedDAG.topologicalDo(v -> {
- if (!v.equals(vertex7) && !v.equals(alsLoop)) {
+ if (!v.equals(vertex6) && !v.equals(alsLoop)) {
builder.addVertex(v);
groupedDAG.getIncomingEdgesOf(v).forEach(builder::connectVertices);
} else if (v.equals(alsLoop)) {
builder.addVertex(v);
groupedDAG.getIncomingEdgesOf(v).forEach(e -> {
- if (!e.getSrc().equals(vertex7)) {
+ if (!e.getSrc().equals(vertex6)) {
builder.connectVertices(e);
} else {
final Optional<IREdge> incomingEdge = newDAGIncomingEdge.stream().findFirst();
diff --git a/examples/beam/src/main/java/edu/snu/nemo/examples/beam/WordCount.java b/examples/beam/src/main/java/edu/snu/nemo/examples/beam/WordCount.java
index 88ed4ae..3d3c556 100644
--- a/examples/beam/src/main/java/edu/snu/nemo/examples/beam/WordCount.java
+++ b/examples/beam/src/main/java/edu/snu/nemo/examples/beam/WordCount.java
@@ -56,8 +56,7 @@ public final class WordCount {
return KV.of(documentId, count);
}
}))
- .apply(GroupByKey.<String, Long>create())
- .apply(Combine.<String, Long, Long>groupedValues(Sum.ofLongs()))
+ .apply(Sum.longsPerKey())
.apply(MapElements.<KV<String, Long>, String>via(new SimpleFunction<KV<String, Long>, String>() {
@Override
public String apply(final KV<String, Long> kv) {