You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nemo.apache.org by jo...@apache.org on 2019/01/22 00:47:05 UTC

[incubator-nemo] branch reshaping updated: handle loops

This is an automated email from the ASF dual-hosted git repository.

johnyangk pushed a commit to branch reshaping
in repository https://gitbox.apache.org/repos/asf/incubator-nemo.git


The following commit(s) were added to refs/heads/reshaping by this push:
     new 4319fb1  handle loops
4319fb1 is described below

commit 4319fb1984bef3564e1b77977734223bb2c3ff50
Author: John Yang <jo...@apache.org>
AuthorDate: Tue Jan 22 09:44:30 2019 +0900

    handle loops
---
 .../compiletime/reshaping/LoopOptimizations.java   | 271 +++++++++++----------
 .../compiletime/reshaping/LoopUnrollingPass.java   |   9 +-
 2 files changed, 145 insertions(+), 135 deletions(-)

diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopOptimizations.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopOptimizations.java
index 8cae875..27afa4a 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopOptimizations.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopOptimizations.java
@@ -18,6 +18,7 @@
  */
 package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;
 
+import org.apache.nemo.common.dag.DAG;
 import org.apache.nemo.common.dag.DAGBuilder;
 import org.apache.nemo.common.ir.IRDAG;
 import org.apache.nemo.common.ir.edge.IREdge;
@@ -64,7 +65,7 @@ public final class LoopOptimizations {
    * @param outEdges outgoing Edges of LoopVertices.
    * @param builder builder to build the rest of the DAG on.
    */
-  private static void collectLoopVertices(final IRDAG dag,
+  private static void collectLoopVertices(final DAG<IRVertex, IREdge> dag,
                                           final List<LoopVertex> loopVertices,
                                           final Map<LoopVertex, List<IREdge>> inEdges,
                                           final Map<LoopVertex, List<IREdge>> outEdges,
@@ -112,93 +113,95 @@ public final class LoopOptimizations {
     }
 
     @Override
-    public void optimize(final IRDAG dag) {
-      final List<LoopVertex> loopVertices = new ArrayList<>();
-      final Map<LoopVertex, List<IREdge>> inEdges = new HashMap<>();
-      final Map<LoopVertex, List<IREdge>> outEdges = new HashMap<>();
-      final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
+    public void optimize(final IRDAG inputDAG) {
+      inputDAG.unSafeDirectReshaping(dag -> {
+        final List<LoopVertex> loopVertices = new ArrayList<>();
+        final Map<LoopVertex, List<IREdge>> inEdges = new HashMap<>();
+        final Map<LoopVertex, List<IREdge>> outEdges = new HashMap<>();
+        final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
 
-      collectLoopVertices(dag, loopVertices, inEdges, outEdges, builder);
+        collectLoopVertices(dag, loopVertices, inEdges, outEdges, builder);
 
-      // Collect and group those with same termination condition.
-      final Set<Set<LoopVertex>> setOfLoopsToBeFused = new HashSet<>();
-      loopVertices.forEach(loopVertex -> {
-        final IntPredicate terminationCondition = loopVertex.getTerminationCondition();
-        final Integer numberOfIterations = loopVertex.getMaxNumberOfIterations();
-        // We want loopVertices that are not dependent on each other or the list that is potentially going to be merged.
-        final List<LoopVertex> independentLoops = loopVertices.stream().filter(loop ->
+        // Collect and group those with same termination condition.
+        final Set<Set<LoopVertex>> setOfLoopsToBeFused = new HashSet<>();
+        loopVertices.forEach(loopVertex -> {
+          final IntPredicate terminationCondition = loopVertex.getTerminationCondition();
+          final Integer numberOfIterations = loopVertex.getMaxNumberOfIterations();
+          // We want loopVertices that are not dependent on each other or the list that is potentially going to be merged.
+          final List<LoopVertex> independentLoops = loopVertices.stream().filter(loop ->
             setOfLoopsToBeFused.stream().anyMatch(list -> list.contains(loop))
-                ? setOfLoopsToBeFused.stream().filter(list -> list.contains(loop))
-                .findFirst()
-                .map(list -> list.stream().noneMatch(loopV -> dag.pathExistsBetween(loopV, loopVertex)))
-                .orElse(false)
-                : !dag.pathExistsBetween(loop, loopVertex)).collect(Collectors.toList());
+              ? setOfLoopsToBeFused.stream().filter(list -> list.contains(loop))
+              .findFirst()
+              .map(list -> list.stream().noneMatch(loopV -> dag.pathExistsBetween(loopV, loopVertex)))
+              .orElse(false)
+              : !dag.pathExistsBetween(loop, loopVertex)).collect(Collectors.toList());
 
-        // Find loops to be fused together.
-        final Set<LoopVertex> loopsToBeFused = new HashSet<>();
-        loopsToBeFused.add(loopVertex);
-        independentLoops.forEach(independentLoop -> {
-          // add them to the list if those independent loops have equal termination conditions.
-          if (loopVertex.terminationConditionEquals(independentLoop)) {
-            loopsToBeFused.add(independentLoop);
-          }
-        });
+          // Find loops to be fused together.
+          final Set<LoopVertex> loopsToBeFused = new HashSet<>();
+          loopsToBeFused.add(loopVertex);
+          independentLoops.forEach(independentLoop -> {
+            // add them to the list if those independent loops have equal termination conditions.
+            if (loopVertex.terminationConditionEquals(independentLoop)) {
+              loopsToBeFused.add(independentLoop);
+            }
+          });
 
-        // add this information to the setOfLoopsToBeFused set.
-        final Optional<Set<LoopVertex>> listToAddVerticesTo = setOfLoopsToBeFused.stream()
+          // add this information to the setOfLoopsToBeFused set.
+          final Optional<Set<LoopVertex>> listToAddVerticesTo = setOfLoopsToBeFused.stream()
             .filter(list -> list.stream().anyMatch(loopsToBeFused::contains)).findFirst();
-        if (listToAddVerticesTo.isPresent()) {
-          listToAddVerticesTo.get().addAll(loopsToBeFused);
-        } else {
-          setOfLoopsToBeFused.add(loopsToBeFused);
-        }
-      });
+          if (listToAddVerticesTo.isPresent()) {
+            listToAddVerticesTo.get().addAll(loopsToBeFused);
+          } else {
+            setOfLoopsToBeFused.add(loopsToBeFused);
+          }
+        });
 
-      // merge and add to builder.
-      setOfLoopsToBeFused.forEach(loops -> {
-        if (loops.size() > 1) {
-          final LoopVertex newLoopVertex = mergeLoopVertices(loops);
-          builder.addVertex(newLoopVertex, dag);
-          loops.forEach(loopVertex -> {
-            // inEdges.
-            inEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(irEdge -> {
-              if (builder.contains(irEdge.getSrc())) {
-                final IREdge newIREdge = new IREdge(irEdge.getPropertyValue(CommunicationPatternProperty.class)
+        // merge and add to builder.
+        setOfLoopsToBeFused.forEach(loops -> {
+          if (loops.size() > 1) {
+            final LoopVertex newLoopVertex = mergeLoopVertices(loops);
+            builder.addVertex(newLoopVertex, dag);
+            loops.forEach(loopVertex -> {
+              // inEdges.
+              inEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(irEdge -> {
+                if (builder.contains(irEdge.getSrc())) {
+                  final IREdge newIREdge = new IREdge(irEdge.getPropertyValue(CommunicationPatternProperty.class)
                     .get(), irEdge.getSrc(), newLoopVertex);
-                irEdge.copyExecutionPropertiesTo(newIREdge);
-                builder.connectVertices(newIREdge);
-              }
-            });
-            // outEdges.
-            outEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(irEdge -> {
-              if (builder.contains(irEdge.getDst())) {
-                final IREdge newIREdge = new IREdge(irEdge.getPropertyValue(CommunicationPatternProperty.class)
+                  irEdge.copyExecutionPropertiesTo(newIREdge);
+                  builder.connectVertices(newIREdge);
+                }
+              });
+              // outEdges.
+              outEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(irEdge -> {
+                if (builder.contains(irEdge.getDst())) {
+                  final IREdge newIREdge = new IREdge(irEdge.getPropertyValue(CommunicationPatternProperty.class)
                     .get(), newLoopVertex, irEdge.getDst());
-                irEdge.copyExecutionPropertiesTo(newIREdge);
-                builder.connectVertices(newIREdge);
-              }
-            });
-          });
-        } else {
-          loops.forEach(loopVertex -> {
-            builder.addVertex(loopVertex);
-            // inEdges.
-            inEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(edge -> {
-              if (builder.contains(edge.getSrc())) {
-                builder.connectVertices(edge);
-              }
+                  irEdge.copyExecutionPropertiesTo(newIREdge);
+                  builder.connectVertices(newIREdge);
+                }
+              });
             });
-            // outEdges.
-            outEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(edge -> {
-              if (builder.contains(edge.getDst())) {
-                builder.connectVertices(edge);
-              }
+          } else {
+            loops.forEach(loopVertex -> {
+              builder.addVertex(loopVertex);
+              // inEdges.
+              inEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(edge -> {
+                if (builder.contains(edge.getSrc())) {
+                  builder.connectVertices(edge);
+                }
+              });
+              // outEdges.
+              outEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(edge -> {
+                if (builder.contains(edge.getDst())) {
+                  builder.connectVertices(edge);
+                }
+              });
             });
-          });
-        }
-      });
+          }
+        });
 
-      return builder.build();
+        return builder.build();
+      });
     }
 
     /**
@@ -211,7 +214,7 @@ public final class LoopOptimizations {
           String.join("+", loopVertices.stream().map(LoopVertex::getName).collect(Collectors.toList()));
       final LoopVertex mergedLoopVertex = new LoopVertex(newName);
       loopVertices.forEach(loopVertex -> {
-        final IRDAG dagToCopy = loopVertex.getDAG();
+        final DAG<IRVertex, IREdge> dagToCopy = loopVertex.getDAG();
         dagToCopy.topologicalDo(v -> {
           mergedLoopVertex.getBuilder().addVertex(v);
           dagToCopy.getIncomingEdgesOf(v).forEach(mergedLoopVertex.getBuilder()::connectVertices);
@@ -242,67 +245,71 @@ public final class LoopOptimizations {
     @Override
     public void optimize(final IRDAG inputDAG) {
       inputDAG.unSafeDirectReshaping(dag -> {
-        final List<LoopVertex> loopVertices = new ArrayList<>();
-        final Map<LoopVertex, List<IREdge>> inEdges = new HashMap<>();
-        final Map<LoopVertex, List<IREdge>> outEdges = new HashMap<>();
-        final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
+        return recursivelyOptimize(dag);
+      });
+    }
 
-        collectLoopVertices(dag, loopVertices, inEdges, outEdges, builder);
+    DAG<IRVertex, IREdge> recursivelyOptimize(final DAG<IRVertex, IREdge> dag) {
+      final List<LoopVertex> loopVertices = new ArrayList<>();
+      final Map<LoopVertex, List<IREdge>> inEdges = new HashMap<>();
+      final Map<LoopVertex, List<IREdge>> outEdges = new HashMap<>();
+      final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
 
-        // Refactor those with same data scan / operation, without dependencies in the loop.
-        loopVertices.forEach(loopVertex -> {
-          final List<Map.Entry<IRVertex, Set<IREdge>>> candidates = loopVertex.getNonIterativeIncomingEdges().entrySet()
-            .stream().filter(entry ->
-              loopVertex.getDAG().getIncomingEdgesOf(entry.getKey()).size() == 0 // no internal inEdges
-                // no external inEdges
-                && loopVertex.getIterativeIncomingEdges().getOrDefault(entry.getKey(), new HashSet<>()).size() == 0)
-            .collect(Collectors.toList());
-          candidates.forEach(candidate -> {
-            // add refactored vertex to builder.
-            builder.addVertex(candidate.getKey());
-            // connect incoming edges.
-            candidate.getValue().forEach(builder::connectVertices);
-            // connect outgoing edges.
-            loopVertex.getDAG().getOutgoingEdgesOf(candidate.getKey()).forEach(loopVertex::addDagIncomingEdge);
-            loopVertex.getDAG().getOutgoingEdgesOf(candidate.getKey()).forEach(loopVertex::addNonIterativeIncomingEdge);
-            // modify incoming edges of loopVertex.
-            final List<IREdge> edgesToRemove = new ArrayList<>();
-            final List<IREdge> edgesToAdd = new ArrayList<>();
-            inEdges.getOrDefault(loopVertex, new ArrayList<>()).stream().filter(e ->
-              // filter edges that have their sources as the refactored vertices.
-              candidate.getValue().stream().map(IREdge::getSrc).anyMatch(edgeSrc -> edgeSrc.equals(e.getSrc())))
-              .forEach(edge -> {
-                edgesToRemove.add(edge);
-                final IREdge newEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
-                  candidate.getKey(), edge.getDst());
-                newEdge.setProperty(EncoderProperty.of(edge.getPropertyValue(EncoderProperty.class).get()));
-                newEdge.setProperty(DecoderProperty.of(edge.getPropertyValue(DecoderProperty.class).get()));
-                edgesToAdd.add(newEdge);
-              });
-            final List<IREdge> listToModify = inEdges.getOrDefault(loopVertex, new ArrayList<>());
-            listToModify.removeAll(edgesToRemove);
-            listToModify.addAll(edgesToAdd);
-            // clear garbage.
-            loopVertex.getBuilder().removeVertex(candidate.getKey());
-            loopVertex.getDagIncomingEdges().remove(candidate.getKey());
-            loopVertex.getNonIterativeIncomingEdges().remove(candidate.getKey());
-          });
-        });
+      collectLoopVertices(dag, loopVertices, inEdges, outEdges, builder);
 
-        // Add LoopVertices.
-        loopVertices.forEach(loopVertex -> {
-          builder.addVertex(loopVertex);
-          inEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(builder::connectVertices);
-          outEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(builder::connectVertices);
+      // Refactor those with same data scan / operation, without dependencies in the loop.
+      loopVertices.forEach(loopVertex -> {
+        final List<Map.Entry<IRVertex, Set<IREdge>>> candidates = loopVertex.getNonIterativeIncomingEdges().entrySet()
+          .stream().filter(entry ->
+            loopVertex.getDAG().getIncomingEdgesOf(entry.getKey()).size() == 0 // no internal inEdges
+              // no external inEdges
+              && loopVertex.getIterativeIncomingEdges().getOrDefault(entry.getKey(), new HashSet<>()).size() == 0)
+          .collect(Collectors.toList());
+        candidates.forEach(candidate -> {
+          // add refactored vertex to builder.
+          builder.addVertex(candidate.getKey());
+          // connect incoming edges.
+          candidate.getValue().forEach(builder::connectVertices);
+          // connect outgoing edges.
+          loopVertex.getDAG().getOutgoingEdgesOf(candidate.getKey()).forEach(loopVertex::addDagIncomingEdge);
+          loopVertex.getDAG().getOutgoingEdgesOf(candidate.getKey()).forEach(loopVertex::addNonIterativeIncomingEdge);
+          // modify incoming edges of loopVertex.
+          final List<IREdge> edgesToRemove = new ArrayList<>();
+          final List<IREdge> edgesToAdd = new ArrayList<>();
+          inEdges.getOrDefault(loopVertex, new ArrayList<>()).stream().filter(e ->
+            // filter edges that have their sources as the refactored vertices.
+            candidate.getValue().stream().map(IREdge::getSrc).anyMatch(edgeSrc -> edgeSrc.equals(e.getSrc())))
+            .forEach(edge -> {
+              edgesToRemove.add(edge);
+              final IREdge newEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
+                candidate.getKey(), edge.getDst());
+              newEdge.setProperty(EncoderProperty.of(edge.getPropertyValue(EncoderProperty.class).get()));
+              newEdge.setProperty(DecoderProperty.of(edge.getPropertyValue(DecoderProperty.class).get()));
+              edgesToAdd.add(newEdge);
+            });
+          final List<IREdge> listToModify = inEdges.getOrDefault(loopVertex, new ArrayList<>());
+          listToModify.removeAll(edgesToRemove);
+          listToModify.addAll(edgesToAdd);
+          // clear garbage.
+          loopVertex.getBuilder().removeVertex(candidate.getKey());
+          loopVertex.getDagIncomingEdges().remove(candidate.getKey());
+          loopVertex.getNonIterativeIncomingEdges().remove(candidate.getKey());
         });
+      });
 
-        final IRDAG newDag = builder.build();
-        if (dag.getVertices().size() == newDag.getVertices().size()) {
-          return newDag;
-        } else {
-          return apply(newDag);
-        }
+      // Add LoopVertices.
+      loopVertices.forEach(loopVertex -> {
+        builder.addVertex(loopVertex);
+        inEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(builder::connectVertices);
+        outEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(builder::connectVertices);
       });
+
+      final DAG<IRVertex, IREdge> newDag = builder.build();
+      if (dag.getVertices().size() == newDag.getVertices().size()) {
+        return newDag;
+      } else {
+        return recursivelyOptimize(newDag);
+      }
     }
   }
 }
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopUnrollingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopUnrollingPass.java
index 1e298c1..7d2219d 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopUnrollingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopUnrollingPass.java
@@ -18,6 +18,7 @@
  */
 package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;
 
+import org.apache.nemo.common.dag.DAG;
 import org.apache.nemo.common.dag.DAGBuilder;
 import org.apache.nemo.common.ir.IRDAG;
 import org.apache.nemo.common.ir.edge.IREdge;
@@ -38,8 +39,10 @@ public final class LoopUnrollingPass extends ReshapingPass {
   }
 
   @Override
-  public void optimize(final IRDAG dag) {
-    return recursivelyUnroll(dag);
+  public void optimize(final IRDAG inputDAG) {
+    inputDAG.unSafeDirectReshaping(dag -> {
+      return recursivelyUnroll(dag);
+    });
   }
 
   /**
@@ -47,7 +50,7 @@ public final class LoopUnrollingPass extends ReshapingPass {
    * @param dag DAG to process.
    * @return DAG without LoopVertex.
    */
-  private IRDAG recursivelyUnroll(final IRDAG dag) {
+  private DAG<IRVertex, IREdge> recursivelyUnroll(final DAG<IRVertex, IREdge> dag) {
     final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
 
     dag.topologicalDo(irVertex -> {