You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@nemo.apache.org by GitBox <gi...@apache.org> on 2018/06/26 06:46:16 UTC

[GitHub] johnyangk closed pull request #54: [NEMO-126] Split ScheduleGroup by pull StageEdges by PhysicalPlanGenerator

johnyangk closed pull request #54: [NEMO-126] Split ScheduleGroup by pull StageEdges by PhysicalPlanGenerator
URL: https://github.com/apache/incubator-nemo/pull/54
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, GitHub no longer shows its
diff once it has been merged, so the full diff is reproduced below:

diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java
index 55c16c51..38f964cb 100644
--- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java
+++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java
@@ -113,6 +113,9 @@ public DefaultScheduleGroupPass(final boolean allowBroadcastWithinScheduleGroup,
         if (skip) {
           continue;
         }
+        if (irVertexToScheduleGroupMap.containsKey(connectedIRVertex)) {
+          continue;
+        }
         // Now we can assure that all vertices that connectedIRVertex depends on have assigned a ScheduleGroup
 
         // Get ScheduleGroup(s) that push data to the connectedIRVertex
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
index d6793876..107a6231 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
@@ -16,6 +16,7 @@
 package edu.snu.nemo.runtime.common.plan;
 
 import edu.snu.nemo.common.ir.Readable;
+import edu.snu.nemo.common.ir.edge.executionproperty.DataFlowModelProperty;
 import edu.snu.nemo.common.ir.edge.executionproperty.DuplicateEdgeGroupProperty;
 import edu.snu.nemo.common.ir.edge.executionproperty.DuplicateEdgeGroupPropertyValue;
 import edu.snu.nemo.common.ir.executionproperty.ExecutionPropertyMap;
@@ -31,6 +32,7 @@
 import edu.snu.nemo.common.exception.IllegalVertexOperationException;
 import edu.snu.nemo.common.exception.PhysicalPlanGenerationException;
 import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import org.apache.commons.lang3.mutable.MutableInt;
 import org.apache.reef.tang.annotations.Parameter;
 
 import javax.inject.Inject;
@@ -75,6 +77,10 @@ private PhysicalPlanGenerator(final StagePartitioner stagePartitioner,
     // this is needed because of DuplicateEdgeGroupProperty.
     handleDuplicateEdgeGroupProperty(dagOfStages);
 
+
+    // Split ScheduleGroups by Pull StageEdges
+    splitScheduleGroupByPullStageEdges(dagOfStages);
+
     // for debugging purposes.
     dagOfStages.storeJSON(dagDirectory, "plan-logical", "logical execution plan");
 
@@ -232,4 +238,76 @@ private void integrityCheck(final Stage stage) {
       }
     });
   }
+
+  /**
+   * Re-assigns {@link ScheduleGroupIndexProperty} so that ScheduleGroups are split at Pull {@link StageEdge}s;
+   * new indices are handed out starting from 0 during a topological traversal of the stages.
+   * Throws {@link RuntimeException} if the non-Pull inEdges of a stage connect stages whose current indices differ.
+   * @param dag {@link DAG} of {@link Stage}s; each stage and its IR vertices receive the new index
+   */
+  private void splitScheduleGroupByPullStageEdges(final DAG<Stage, StageEdge> dag) {
+    final MutableInt nextScheduleGroupIndex = new MutableInt(0);
+    final Map<Stage, Integer> stageToScheduleGroupIndexMap = new HashMap<>();
+
+    dag.topologicalDo(currentStage -> {
+      // Base case: a stage that has no new ScheduleGroupIndex yet gets a fresh one
+      stageToScheduleGroupIndexMap.computeIfAbsent(currentStage, s -> getAndIncrement(nextScheduleGroupIndex));
+
+      for (final StageEdge stageEdgeFromCurrentStage : dag.getOutgoingEdgesOf(currentStage)) {
+        final Stage destination = stageEdgeFromCurrentStage.getDst();
+        // Skip if any Stage that destination depends on has not been assigned a new ScheduleGroupIndex yet
+        boolean skip = false;
+        for (final StageEdge stageEdgeToDestination : dag.getIncomingEdgesOf(destination)) {
+          if (!stageToScheduleGroupIndexMap.containsKey(stageEdgeToDestination.getSrc())) {
+            skip = true;
+            break;
+          }
+        }
+        if (skip) {
+          continue;
+        }
+        if (stageToScheduleGroupIndexMap.containsKey(destination)) { // destination was already assigned earlier
+          continue;
+        }
+
+        // Look for non-Pull inEdges: destination inherits the new index of the source of such an edge
+        Integer scheduleGroupIndex = null; // original index of a non-Pull source, if any was seen
+        Integer newScheduleGroupIndex = null; // new index already assigned to that source
+        for (final StageEdge stageEdge : dag.getIncomingEdgesOf(destination)) {
+          final Stage source = stageEdge.getSrc();
+          if (stageEdge.getDataFlowModel() != DataFlowModelProperty.Value.Pull) {
+            if (scheduleGroupIndex != null && source.getScheduleGroupIndex() != scheduleGroupIndex) {
+              throw new RuntimeException(String.format("Multiple Push inEdges from different ScheduleGroup: %d, %d",
+                  scheduleGroupIndex, source.getScheduleGroupIndex()));
+            }
+            if (source.getScheduleGroupIndex() != destination.getScheduleGroupIndex()) {
+              throw new RuntimeException(String.format("Split ScheduleGroup by push StageEdge: %d, %d",
+                  source.getScheduleGroupIndex(), destination.getScheduleGroupIndex()));
+            }
+            scheduleGroupIndex = source.getScheduleGroupIndex();
+            newScheduleGroupIndex = stageToScheduleGroupIndexMap.get(source);
+          }
+        }
+
+        if (newScheduleGroupIndex == null) { // only Pull inEdges: destination starts a new ScheduleGroup
+          stageToScheduleGroupIndexMap.put(destination, getAndIncrement(nextScheduleGroupIndex));
+        } else {
+          stageToScheduleGroupIndexMap.put(destination, newScheduleGroupIndex);
+        }
+      }
+    });
+
+    dag.topologicalDo(stage -> { // second pass: write the new index to each stage and to its IR vertices
+      final int scheduleGroupIndex = stageToScheduleGroupIndexMap.get(stage);
+      stage.getExecutionProperties().put(ScheduleGroupIndexProperty.of(scheduleGroupIndex));
+      stage.getIRDAG().topologicalDo(vertex -> vertex.getExecutionProperties()
+          .put(ScheduleGroupIndexProperty.of(scheduleGroupIndex)));
+    });
+  }
+
+  private static int getAndIncrement(final MutableInt mutableInt) { // post-increment helper for the counter
+    final int toReturn = mutableInt.getValue(); // value before the increment is what the caller receives
+    mutableInt.increment();
+    return toReturn;
+  }
 }
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Stage.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Stage.java
index ee0a337e..9e5da38b 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Stage.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Stage.java
@@ -40,8 +40,6 @@
   private final byte[] serializedIRDag;
   private final ExecutionPropertyMap<VertexExecutionProperty> executionProperties;
   private final List<Map<String, Readable>> vertexIdToReadables;
-  private final int parallelism;
-  private final int scheduleGroupIndex;
 
   /**
    * Constructor.
@@ -60,10 +58,6 @@ public Stage(final String stageId,
     this.serializedIRDag = SerializationUtils.serialize(irDag);
     this.executionProperties = executionProperties;
     this.vertexIdToReadables = vertexIdToReadables;
-    this.parallelism = executionProperties.get(ParallelismProperty.class)
-        .orElseThrow(() -> new RuntimeException("Parallelism property must be set for Stage"));
-    this.scheduleGroupIndex = executionProperties.get(ScheduleGroupIndexProperty.class)
-        .orElseThrow(() -> new RuntimeException("ScheduleGroupIndex property must be set for Stage"));
   }
 
   /**
@@ -85,17 +79,26 @@ public Stage(final String stageId,
    */
   public List<String> getTaskIds() {
     final List<String> taskIds = new ArrayList<>();
-    for (int taskIdx = 0; taskIdx < parallelism; taskIdx++) {
+    for (int taskIdx = 0; taskIdx < getParallelism(); taskIdx++) {
       taskIds.add(RuntimeIdGenerator.generateTaskId(taskIdx, getId()));
     }
     return taskIds;
   }
 
+  /**
+   * @return the parallelism read from {@link ParallelismProperty}; a RuntimeException is thrown if it is not set
+   */
+  public int getParallelism() {
+    return executionProperties.get(ParallelismProperty.class)
+        .orElseThrow(() -> new RuntimeException("Parallelism property must be set for Stage"));
+  }
+
   /**
    * @return the schedule group index.
    */
   public int getScheduleGroupIndex() {
-    return scheduleGroupIndex;
+    return executionProperties.get(ScheduleGroupIndexProperty.class)
+        .orElseThrow(() -> new RuntimeException("ScheduleGroupIndex property must be set for Stage"));
   }
 
   /**
@@ -127,9 +130,9 @@ public int getScheduleGroupIndex() {
   @Override
   public String propertiesToJSON() {
     final StringBuilder sb = new StringBuilder();
-    sb.append("{\"scheduleGroupIndex\": ").append(scheduleGroupIndex);
+    sb.append("{\"scheduleGroupIndex\": ").append(getScheduleGroupIndex());
     sb.append(", \"irDag\": ").append(irDag);
-    sb.append(", \"parallelism\": ").append(parallelism);
+    sb.append(", \"parallelism\": ").append(getParallelism());
     sb.append(", \"executionProperties\": ").append(executionProperties);
     sb.append('}');
     return sb.toString();
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java
index d2697f4a..a45b9b60 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java
@@ -16,6 +16,9 @@
 package edu.snu.nemo.runtime.common.plan;
 
 import com.google.common.annotations.VisibleForTesting;
+import edu.snu.nemo.common.ir.edge.executionproperty.DataCommunicationPatternProperty;
+import edu.snu.nemo.common.ir.edge.executionproperty.DataFlowModelProperty;
+import edu.snu.nemo.common.ir.executionproperty.EdgeExecutionProperty;
 import edu.snu.nemo.common.ir.vertex.IRVertex;
 import edu.snu.nemo.common.ir.executionproperty.ExecutionPropertyMap;
 import edu.snu.nemo.runtime.common.data.KeyRange;
@@ -46,6 +49,16 @@
    */
   private List<KeyRange> taskIdxToKeyRange;
 
+  /**
+   * Value for {@link DataCommunicationPatternProperty}.
+   */
+  private final DataCommunicationPatternProperty.Value dataCommunicationPatternValue;
+
+  /**
+   * Value for {@link DataFlowModelProperty}.
+   */
+  private final DataFlowModelProperty.Value dataFlowModelValue;
+
   /**
    * Constructor.
    *
@@ -59,7 +72,7 @@
    */
   @VisibleForTesting
   public StageEdge(final String runtimeEdgeId,
-            final ExecutionPropertyMap edgeProperties,
+            final ExecutionPropertyMap<EdgeExecutionProperty> edgeProperties,
             final IRVertex srcVertex,
             final IRVertex dstVertex,
             final Stage srcStage,
@@ -73,6 +86,12 @@ public StageEdge(final String runtimeEdgeId,
     for (int taskIdx = 0; taskIdx < dstStage.getTaskIds().size(); taskIdx++) {
       taskIdxToKeyRange.add(HashRange.of(taskIdx, taskIdx + 1));
     }
+    this.dataCommunicationPatternValue = edgeProperties.get(DataCommunicationPatternProperty.class)
+        .orElseThrow(() -> new RuntimeException(String.format(
+            "DataCommunicationPatternProperty not set for %s", runtimeEdgeId)));
+    this.dataFlowModelValue = edgeProperties.get(DataFlowModelProperty.class)
+        .orElseThrow(() -> new RuntimeException(String.format(
+            "DataFlowModelProperty not set for %s", runtimeEdgeId)));
   }
 
   /**
@@ -115,4 +134,18 @@ public String propertiesToJSON() {
   public void setTaskIdxToKeyRange(final List<KeyRange> taskIdxToKeyRange) {
     this.taskIdxToKeyRange = taskIdxToKeyRange;
   }
+
+  /**
+   * @return the {@link DataCommunicationPatternProperty} value read from the edge properties in the constructor.
+   */
+  public DataCommunicationPatternProperty.Value getDataCommunicationPattern() {
+    return dataCommunicationPatternValue;
+  }
+
+  /**
+   * @return the {@link DataFlowModelProperty} value read from the edge properties in the constructor.
+   */
+  public DataFlowModelProperty.Value getDataFlowModel() {
+    return dataFlowModelValue;
+  }
 }
diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FaultToleranceTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FaultToleranceTest.java
index 0e57dc5b..62b51812 100644
--- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FaultToleranceTest.java
+++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FaultToleranceTest.java
@@ -102,7 +102,7 @@ private Scheduler setUpScheduler(final boolean useMockSchedulerRunner) throws In
   /**
    * Tests fault tolerance after a container removal.
    */
-  @Test(timeout=5000)
+  @Test(timeout=50000)
   public void testContainerRemoval() throws Exception {
     final ActiveContext activeContext = mock(ActiveContext.class);
     Mockito.doThrow(new RuntimeException()).when(activeContext).close();
@@ -180,7 +180,7 @@ public void testContainerRemoval() throws Exception {
   /**
    * Tests fault tolerance after an output write failure.
    */
-  @Test(timeout=5000)
+  @Test(timeout=50000)
   public void testOutputFailure() throws Exception {
     final ActiveContext activeContext = mock(ActiveContext.class);
     Mockito.doThrow(new RuntimeException()).when(activeContext).close();
@@ -245,7 +245,7 @@ public void testOutputFailure() throws Exception {
   /**
    * Tests fault tolerance after an input read failure.
    */
-  @Test(timeout=5000)
+  @Test(timeout=50000)
   public void testInputReadFailure() throws Exception {
     final ActiveContext activeContext = mock(ActiveContext.class);
     Mockito.doThrow(new RuntimeException()).when(activeContext).close();
@@ -311,7 +311,7 @@ public void testInputReadFailure() throws Exception {
   /**
    * Tests the rescheduling of Tasks upon a failure.
    */
-  @Test(timeout=20000)
+  @Test(timeout=200000)
   public void testTaskReexecutionForFailure() throws Exception {
     final ActiveContext activeContext = mock(ActiveContext.class);
     Mockito.doThrow(new RuntimeException()).when(activeContext).close();
diff --git a/tests/src/test/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGeneratorTest.java b/tests/src/test/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGeneratorTest.java
new file mode 100644
index 00000000..f8c4f902
--- /dev/null
+++ b/tests/src/test/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGeneratorTest.java
@@ -0,0 +1,83 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *         http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.runtime.common.plan;
+
+import edu.snu.nemo.common.dag.DAG;
+import edu.snu.nemo.common.dag.DAGBuilder;
+import edu.snu.nemo.common.ir.edge.IREdge;
+import edu.snu.nemo.common.ir.edge.executionproperty.DataCommunicationPatternProperty;
+import edu.snu.nemo.common.ir.edge.executionproperty.DataFlowModelProperty;
+import edu.snu.nemo.common.ir.vertex.IRVertex;
+import edu.snu.nemo.common.ir.vertex.OperatorVertex;
+import edu.snu.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
+import edu.snu.nemo.common.ir.vertex.executionproperty.ScheduleGroupIndexProperty;
+import edu.snu.nemo.common.ir.vertex.transform.Transform;
+import edu.snu.nemo.common.test.EmptyComponents;
+import org.apache.reef.tang.Injector;
+import org.apache.reef.tang.Tang;
+import org.junit.Test;
+
+import java.util.Iterator;
+
+import static org.junit.Assert.assertNotEquals;
+
+/**
+ * Tests {@link PhysicalPlanGenerator}, currently only its handling of ScheduleGroup indices across Pull edges.
+ */
+public final class PhysicalPlanGeneratorTest {
+  private static final Transform EMPTY_TRANSFORM = new EmptyComponents.EmptyTransform("");
+
+  /**
+   * Checks that two vertices tagged with ScheduleGroup 0 and joined by a Pull edge yield stages with distinct indices.
+   * @throws Exception any exception raised while building the injector or generating the plan
+   */
+  @Test
+  public void testSplitScheduleGroupByPullStageEdges() throws Exception {
+    final Injector injector = Tang.Factory.getTang().newInjector();
+    final PhysicalPlanGenerator physicalPlanGenerator = injector.getInstance(PhysicalPlanGenerator.class);
+
+    final IRVertex v0 = newIRVertex(0, 5); // ScheduleGroup 0, parallelism 5
+    final IRVertex v1 = newIRVertex(0, 3); // ScheduleGroup 0, parallelism 3
+    final DAG<IRVertex, IREdge> irDAG = new DAGBuilder<IRVertex, IREdge>()
+        .addVertex(v0)
+        .addVertex(v1)
+        .connectVertices(newIREdge(v0, v1, DataCommunicationPatternProperty.Value.OneToOne,
+            DataFlowModelProperty.Value.Pull))
+        .buildWithoutSourceSinkCheck();
+
+    final DAG<Stage, StageEdge> stageDAG = physicalPlanGenerator.apply(irDAG);
+    final Iterator<Stage> stages = stageDAG.getVertices().iterator();
+    final Stage s0 = stages.next(); // NOTE(review): assumes the plan holds at least two stages - confirm
+    final Stage s1 = stages.next();
+
+    assertNotEquals(s0.getScheduleGroupIndex(), s1.getScheduleGroupIndex());
+  }
+
+  private static final IRVertex newIRVertex(final int scheduleGroupIndex, final int parallelism) {
+    final IRVertex irVertex = new OperatorVertex(EMPTY_TRANSFORM); // every test vertex shares EMPTY_TRANSFORM
+    irVertex.getExecutionProperties().put(ScheduleGroupIndexProperty.of(scheduleGroupIndex));
+    irVertex.getExecutionProperties().put(ParallelismProperty.of(parallelism));
+    return irVertex;
+  }
+
+  private static final IREdge newIREdge(final IRVertex src, final IRVertex dst,
+                                        final DataCommunicationPatternProperty.Value communicationPattern,
+                                        final DataFlowModelProperty.Value dataFlowModel) {
+    final IREdge irEdge = new IREdge(communicationPattern, src, dst); // data flow model is set on the next line
+    irEdge.getExecutionProperties().put(DataFlowModelProperty.of(dataFlowModel));
+    return irEdge;
+  }
+}
diff --git a/tests/src/test/java/edu/snu/nemo/tests/compiler/CompilerTestUtil.java b/tests/src/test/java/edu/snu/nemo/tests/compiler/CompilerTestUtil.java
index 63c02d98..d5f17f4a 100644
--- a/tests/src/test/java/edu/snu/nemo/tests/compiler/CompilerTestUtil.java
+++ b/tests/src/test/java/edu/snu/nemo/tests/compiler/CompilerTestUtil.java
@@ -59,7 +59,7 @@
     return captor.getValue();
   }
 
-  public static DAG<IRVertex, IREdge> compileMRDAG() throws Exception {
+  public static DAG<IRVertex, IREdge> compileWordCountDAG() throws Exception {
     final String input = rootDir + "/../examples/resources/sample_input_wordcount";
     final String output = rootDir + "/../examples/resources/sample_output";
     final String main = "edu.snu.nemo.examples.beam.WordCount";
diff --git a/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/annotating/DefaultEdgeCoderPassTest.java b/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/annotating/DefaultEdgeCoderPassTest.java
index 002873c7..7f9dc3a8 100644
--- a/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/annotating/DefaultEdgeCoderPassTest.java
+++ b/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/annotating/DefaultEdgeCoderPassTest.java
@@ -45,7 +45,7 @@
 
   @Before
   public void setUp() throws Exception {
-    compiledDAG = CompilerTestUtil.compileMRDAG();
+    compiledDAG = CompilerTestUtil.compileWordCountDAG();
   }
 
   @Test
diff --git a/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/composite/DataSkewCompositePassTest.java b/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/composite/DataSkewCompositePassTest.java
index 50b3cdba..7f81154c 100644
--- a/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/composite/DataSkewCompositePassTest.java
+++ b/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/composite/DataSkewCompositePassTest.java
@@ -81,7 +81,7 @@ public void testCompositePass() {
    */
   @Test
   public void testDataSkewPass() throws Exception {
-    mrDAG = CompilerTestUtil.compileMRDAG();
+    mrDAG = CompilerTestUtil.compileWordCountDAG();
     final Integer originalVerticesNum = mrDAG.getVertices().size();
     final Long numOfShuffleGatherEdges = mrDAG.getVertices().stream().filter(irVertex ->
         mrDAG.getIncomingEdgesOf(irVertex).stream().anyMatch(irEdge ->


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services