You are viewing a plain-text version of this content; the hyperlink to its canonical version is not preserved in this rendering.
Posted to commits@nemo.apache.org by jo...@apache.org on 2019/01/22 00:47:05 UTC
[incubator-nemo] branch reshaping updated: handle loops
This is an automated email from the ASF dual-hosted git repository.
johnyangk pushed a commit to branch reshaping
in repository https://gitbox.apache.org/repos/asf/incubator-nemo.git
The following commit(s) were added to refs/heads/reshaping by this push:
new 4319fb1 handle loops
4319fb1 is described below
commit 4319fb1984bef3564e1b77977734223bb2c3ff50
Author: John Yang <jo...@apache.org>
AuthorDate: Tue Jan 22 09:44:30 2019 +0900
handle loops
---
.../compiletime/reshaping/LoopOptimizations.java | 271 +++++++++++----------
.../compiletime/reshaping/LoopUnrollingPass.java | 9 +-
2 files changed, 145 insertions(+), 135 deletions(-)
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopOptimizations.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopOptimizations.java
index 8cae875..27afa4a 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopOptimizations.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopOptimizations.java
@@ -18,6 +18,7 @@
*/
package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;
+import org.apache.nemo.common.dag.DAG;
import org.apache.nemo.common.dag.DAGBuilder;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
@@ -64,7 +65,7 @@ public final class LoopOptimizations {
* @param outEdges outgoing Edges of LoopVertices.
* @param builder builder to build the rest of the DAG on.
*/
- private static void collectLoopVertices(final IRDAG dag,
+ private static void collectLoopVertices(final DAG<IRVertex, IREdge> dag,
final List<LoopVertex> loopVertices,
final Map<LoopVertex, List<IREdge>> inEdges,
final Map<LoopVertex, List<IREdge>> outEdges,
@@ -112,93 +113,95 @@ public final class LoopOptimizations {
}
@Override
- public void optimize(final IRDAG dag) {
- final List<LoopVertex> loopVertices = new ArrayList<>();
- final Map<LoopVertex, List<IREdge>> inEdges = new HashMap<>();
- final Map<LoopVertex, List<IREdge>> outEdges = new HashMap<>();
- final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
+ public void optimize(final IRDAG inputDAG) {
+ inputDAG.unSafeDirectReshaping(dag -> {
+ final List<LoopVertex> loopVertices = new ArrayList<>();
+ final Map<LoopVertex, List<IREdge>> inEdges = new HashMap<>();
+ final Map<LoopVertex, List<IREdge>> outEdges = new HashMap<>();
+ final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
- collectLoopVertices(dag, loopVertices, inEdges, outEdges, builder);
+ collectLoopVertices(dag, loopVertices, inEdges, outEdges, builder);
- // Collect and group those with same termination condition.
- final Set<Set<LoopVertex>> setOfLoopsToBeFused = new HashSet<>();
- loopVertices.forEach(loopVertex -> {
- final IntPredicate terminationCondition = loopVertex.getTerminationCondition();
- final Integer numberOfIterations = loopVertex.getMaxNumberOfIterations();
- // We want loopVertices that are not dependent on each other or the list that is potentially going to be merged.
- final List<LoopVertex> independentLoops = loopVertices.stream().filter(loop ->
+ // Collect and group those with same termination condition.
+ final Set<Set<LoopVertex>> setOfLoopsToBeFused = new HashSet<>();
+ loopVertices.forEach(loopVertex -> {
+ final IntPredicate terminationCondition = loopVertex.getTerminationCondition();
+ final Integer numberOfIterations = loopVertex.getMaxNumberOfIterations();
+ // We want loopVertices that are not dependent on each other or the list that is potentially going to be merged.
+ final List<LoopVertex> independentLoops = loopVertices.stream().filter(loop ->
setOfLoopsToBeFused.stream().anyMatch(list -> list.contains(loop))
- ? setOfLoopsToBeFused.stream().filter(list -> list.contains(loop))
- .findFirst()
- .map(list -> list.stream().noneMatch(loopV -> dag.pathExistsBetween(loopV, loopVertex)))
- .orElse(false)
- : !dag.pathExistsBetween(loop, loopVertex)).collect(Collectors.toList());
+ ? setOfLoopsToBeFused.stream().filter(list -> list.contains(loop))
+ .findFirst()
+ .map(list -> list.stream().noneMatch(loopV -> dag.pathExistsBetween(loopV, loopVertex)))
+ .orElse(false)
+ : !dag.pathExistsBetween(loop, loopVertex)).collect(Collectors.toList());
- // Find loops to be fused together.
- final Set<LoopVertex> loopsToBeFused = new HashSet<>();
- loopsToBeFused.add(loopVertex);
- independentLoops.forEach(independentLoop -> {
- // add them to the list if those independent loops have equal termination conditions.
- if (loopVertex.terminationConditionEquals(independentLoop)) {
- loopsToBeFused.add(independentLoop);
- }
- });
+ // Find loops to be fused together.
+ final Set<LoopVertex> loopsToBeFused = new HashSet<>();
+ loopsToBeFused.add(loopVertex);
+ independentLoops.forEach(independentLoop -> {
+ // add them to the list if those independent loops have equal termination conditions.
+ if (loopVertex.terminationConditionEquals(independentLoop)) {
+ loopsToBeFused.add(independentLoop);
+ }
+ });
- // add this information to the setOfLoopsToBeFused set.
- final Optional<Set<LoopVertex>> listToAddVerticesTo = setOfLoopsToBeFused.stream()
+ // add this information to the setOfLoopsToBeFused set.
+ final Optional<Set<LoopVertex>> listToAddVerticesTo = setOfLoopsToBeFused.stream()
.filter(list -> list.stream().anyMatch(loopsToBeFused::contains)).findFirst();
- if (listToAddVerticesTo.isPresent()) {
- listToAddVerticesTo.get().addAll(loopsToBeFused);
- } else {
- setOfLoopsToBeFused.add(loopsToBeFused);
- }
- });
+ if (listToAddVerticesTo.isPresent()) {
+ listToAddVerticesTo.get().addAll(loopsToBeFused);
+ } else {
+ setOfLoopsToBeFused.add(loopsToBeFused);
+ }
+ });
- // merge and add to builder.
- setOfLoopsToBeFused.forEach(loops -> {
- if (loops.size() > 1) {
- final LoopVertex newLoopVertex = mergeLoopVertices(loops);
- builder.addVertex(newLoopVertex, dag);
- loops.forEach(loopVertex -> {
- // inEdges.
- inEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(irEdge -> {
- if (builder.contains(irEdge.getSrc())) {
- final IREdge newIREdge = new IREdge(irEdge.getPropertyValue(CommunicationPatternProperty.class)
+ // merge and add to builder.
+ setOfLoopsToBeFused.forEach(loops -> {
+ if (loops.size() > 1) {
+ final LoopVertex newLoopVertex = mergeLoopVertices(loops);
+ builder.addVertex(newLoopVertex, dag);
+ loops.forEach(loopVertex -> {
+ // inEdges.
+ inEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(irEdge -> {
+ if (builder.contains(irEdge.getSrc())) {
+ final IREdge newIREdge = new IREdge(irEdge.getPropertyValue(CommunicationPatternProperty.class)
.get(), irEdge.getSrc(), newLoopVertex);
- irEdge.copyExecutionPropertiesTo(newIREdge);
- builder.connectVertices(newIREdge);
- }
- });
- // outEdges.
- outEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(irEdge -> {
- if (builder.contains(irEdge.getDst())) {
- final IREdge newIREdge = new IREdge(irEdge.getPropertyValue(CommunicationPatternProperty.class)
+ irEdge.copyExecutionPropertiesTo(newIREdge);
+ builder.connectVertices(newIREdge);
+ }
+ });
+ // outEdges.
+ outEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(irEdge -> {
+ if (builder.contains(irEdge.getDst())) {
+ final IREdge newIREdge = new IREdge(irEdge.getPropertyValue(CommunicationPatternProperty.class)
.get(), newLoopVertex, irEdge.getDst());
- irEdge.copyExecutionPropertiesTo(newIREdge);
- builder.connectVertices(newIREdge);
- }
- });
- });
- } else {
- loops.forEach(loopVertex -> {
- builder.addVertex(loopVertex);
- // inEdges.
- inEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(edge -> {
- if (builder.contains(edge.getSrc())) {
- builder.connectVertices(edge);
- }
+ irEdge.copyExecutionPropertiesTo(newIREdge);
+ builder.connectVertices(newIREdge);
+ }
+ });
});
- // outEdges.
- outEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(edge -> {
- if (builder.contains(edge.getDst())) {
- builder.connectVertices(edge);
- }
+ } else {
+ loops.forEach(loopVertex -> {
+ builder.addVertex(loopVertex);
+ // inEdges.
+ inEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(edge -> {
+ if (builder.contains(edge.getSrc())) {
+ builder.connectVertices(edge);
+ }
+ });
+ // outEdges.
+ outEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(edge -> {
+ if (builder.contains(edge.getDst())) {
+ builder.connectVertices(edge);
+ }
+ });
});
- });
- }
- });
+ }
+ });
- return builder.build();
+ return builder.build();
+ });
}
/**
@@ -211,7 +214,7 @@ public final class LoopOptimizations {
String.join("+", loopVertices.stream().map(LoopVertex::getName).collect(Collectors.toList()));
final LoopVertex mergedLoopVertex = new LoopVertex(newName);
loopVertices.forEach(loopVertex -> {
- final IRDAG dagToCopy = loopVertex.getDAG();
+ final DAG<IRVertex, IREdge> dagToCopy = loopVertex.getDAG();
dagToCopy.topologicalDo(v -> {
mergedLoopVertex.getBuilder().addVertex(v);
dagToCopy.getIncomingEdgesOf(v).forEach(mergedLoopVertex.getBuilder()::connectVertices);
@@ -242,67 +245,71 @@ public final class LoopOptimizations {
@Override
public void optimize(final IRDAG inputDAG) {
inputDAG.unSafeDirectReshaping(dag -> {
- final List<LoopVertex> loopVertices = new ArrayList<>();
- final Map<LoopVertex, List<IREdge>> inEdges = new HashMap<>();
- final Map<LoopVertex, List<IREdge>> outEdges = new HashMap<>();
- final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
+ return recursivelyOptimize(dag);
+ });
+ }
- collectLoopVertices(dag, loopVertices, inEdges, outEdges, builder);
+ DAG<IRVertex, IREdge> recursivelyOptimize(final DAG<IRVertex, IREdge> dag) {
+ final List<LoopVertex> loopVertices = new ArrayList<>();
+ final Map<LoopVertex, List<IREdge>> inEdges = new HashMap<>();
+ final Map<LoopVertex, List<IREdge>> outEdges = new HashMap<>();
+ final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
- // Refactor those with same data scan / operation, without dependencies in the loop.
- loopVertices.forEach(loopVertex -> {
- final List<Map.Entry<IRVertex, Set<IREdge>>> candidates = loopVertex.getNonIterativeIncomingEdges().entrySet()
- .stream().filter(entry ->
- loopVertex.getDAG().getIncomingEdgesOf(entry.getKey()).size() == 0 // no internal inEdges
- // no external inEdges
- && loopVertex.getIterativeIncomingEdges().getOrDefault(entry.getKey(), new HashSet<>()).size() == 0)
- .collect(Collectors.toList());
- candidates.forEach(candidate -> {
- // add refactored vertex to builder.
- builder.addVertex(candidate.getKey());
- // connect incoming edges.
- candidate.getValue().forEach(builder::connectVertices);
- // connect outgoing edges.
- loopVertex.getDAG().getOutgoingEdgesOf(candidate.getKey()).forEach(loopVertex::addDagIncomingEdge);
- loopVertex.getDAG().getOutgoingEdgesOf(candidate.getKey()).forEach(loopVertex::addNonIterativeIncomingEdge);
- // modify incoming edges of loopVertex.
- final List<IREdge> edgesToRemove = new ArrayList<>();
- final List<IREdge> edgesToAdd = new ArrayList<>();
- inEdges.getOrDefault(loopVertex, new ArrayList<>()).stream().filter(e ->
- // filter edges that have their sources as the refactored vertices.
- candidate.getValue().stream().map(IREdge::getSrc).anyMatch(edgeSrc -> edgeSrc.equals(e.getSrc())))
- .forEach(edge -> {
- edgesToRemove.add(edge);
- final IREdge newEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
- candidate.getKey(), edge.getDst());
- newEdge.setProperty(EncoderProperty.of(edge.getPropertyValue(EncoderProperty.class).get()));
- newEdge.setProperty(DecoderProperty.of(edge.getPropertyValue(DecoderProperty.class).get()));
- edgesToAdd.add(newEdge);
- });
- final List<IREdge> listToModify = inEdges.getOrDefault(loopVertex, new ArrayList<>());
- listToModify.removeAll(edgesToRemove);
- listToModify.addAll(edgesToAdd);
- // clear garbage.
- loopVertex.getBuilder().removeVertex(candidate.getKey());
- loopVertex.getDagIncomingEdges().remove(candidate.getKey());
- loopVertex.getNonIterativeIncomingEdges().remove(candidate.getKey());
- });
- });
+ collectLoopVertices(dag, loopVertices, inEdges, outEdges, builder);
- // Add LoopVertices.
- loopVertices.forEach(loopVertex -> {
- builder.addVertex(loopVertex);
- inEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(builder::connectVertices);
- outEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(builder::connectVertices);
+ // Refactor those with same data scan / operation, without dependencies in the loop.
+ loopVertices.forEach(loopVertex -> {
+ final List<Map.Entry<IRVertex, Set<IREdge>>> candidates = loopVertex.getNonIterativeIncomingEdges().entrySet()
+ .stream().filter(entry ->
+ loopVertex.getDAG().getIncomingEdgesOf(entry.getKey()).size() == 0 // no internal inEdges
+ // no external inEdges
+ && loopVertex.getIterativeIncomingEdges().getOrDefault(entry.getKey(), new HashSet<>()).size() == 0)
+ .collect(Collectors.toList());
+ candidates.forEach(candidate -> {
+ // add refactored vertex to builder.
+ builder.addVertex(candidate.getKey());
+ // connect incoming edges.
+ candidate.getValue().forEach(builder::connectVertices);
+ // connect outgoing edges.
+ loopVertex.getDAG().getOutgoingEdgesOf(candidate.getKey()).forEach(loopVertex::addDagIncomingEdge);
+ loopVertex.getDAG().getOutgoingEdgesOf(candidate.getKey()).forEach(loopVertex::addNonIterativeIncomingEdge);
+ // modify incoming edges of loopVertex.
+ final List<IREdge> edgesToRemove = new ArrayList<>();
+ final List<IREdge> edgesToAdd = new ArrayList<>();
+ inEdges.getOrDefault(loopVertex, new ArrayList<>()).stream().filter(e ->
+ // filter edges that have their sources as the refactored vertices.
+ candidate.getValue().stream().map(IREdge::getSrc).anyMatch(edgeSrc -> edgeSrc.equals(e.getSrc())))
+ .forEach(edge -> {
+ edgesToRemove.add(edge);
+ final IREdge newEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
+ candidate.getKey(), edge.getDst());
+ newEdge.setProperty(EncoderProperty.of(edge.getPropertyValue(EncoderProperty.class).get()));
+ newEdge.setProperty(DecoderProperty.of(edge.getPropertyValue(DecoderProperty.class).get()));
+ edgesToAdd.add(newEdge);
+ });
+ final List<IREdge> listToModify = inEdges.getOrDefault(loopVertex, new ArrayList<>());
+ listToModify.removeAll(edgesToRemove);
+ listToModify.addAll(edgesToAdd);
+ // clear garbage.
+ loopVertex.getBuilder().removeVertex(candidate.getKey());
+ loopVertex.getDagIncomingEdges().remove(candidate.getKey());
+ loopVertex.getNonIterativeIncomingEdges().remove(candidate.getKey());
});
+ });
- final IRDAG newDag = builder.build();
- if (dag.getVertices().size() == newDag.getVertices().size()) {
- return newDag;
- } else {
- return apply(newDag);
- }
+ // Add LoopVertices.
+ loopVertices.forEach(loopVertex -> {
+ builder.addVertex(loopVertex);
+ inEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(builder::connectVertices);
+ outEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(builder::connectVertices);
});
+
+ final DAG<IRVertex, IREdge> newDag = builder.build();
+ if (dag.getVertices().size() == newDag.getVertices().size()) {
+ return newDag;
+ } else {
+ return recursivelyOptimize(newDag);
+ }
}
}
}
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopUnrollingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopUnrollingPass.java
index 1e298c1..7d2219d 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopUnrollingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopUnrollingPass.java
@@ -18,6 +18,7 @@
*/
package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;
+import org.apache.nemo.common.dag.DAG;
import org.apache.nemo.common.dag.DAGBuilder;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
@@ -38,8 +39,10 @@ public final class LoopUnrollingPass extends ReshapingPass {
}
@Override
- public void optimize(final IRDAG dag) {
- return recursivelyUnroll(dag);
+ public void optimize(final IRDAG inputDAG) {
+ inputDAG.unSafeDirectReshaping(dag -> {
+ return recursivelyUnroll(dag);
+ });
}
/**
@@ -47,7 +50,7 @@ public final class LoopUnrollingPass extends ReshapingPass {
* @param dag DAG to process.
* @return DAG without LoopVertex.
*/
- private IRDAG recursivelyUnroll(final IRDAG dag) {
+ private DAG<IRVertex, IREdge> recursivelyUnroll(final DAG<IRVertex, IREdge> dag) {
final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
dag.topologicalDo(irVertex -> {