You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nemo.apache.org by jo...@apache.org on 2018/06/26 06:46:16 UTC

[incubator-nemo] branch master updated: [NEMO-126] Split ScheduleGroup by pull StageEdges by PhysicalPlanGenerator (#54)

This is an automated email from the ASF dual-hosted git repository.

johnyangk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-nemo.git


The following commit(s) were added to refs/heads/master by this push:
     new 9e58c32  [NEMO-126] Split ScheduleGroup by pull StageEdges by PhysicalPlanGenerator (#54)
9e58c32 is described below

commit 9e58c329ca9c7e5fce578857e91ea7838d5113e8
Author: Jangho Seo <ja...@jangho.io>
AuthorDate: Tue Jun 26 15:46:14 2018 +0900

    [NEMO-126] Split ScheduleGroup by pull StageEdges by PhysicalPlanGenerator (#54)
    
    JIRA: [NEMO-126: PhysicalPlanGenerator: Split ScheduleGroup by pull StageEdges](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-126)
    
    **Major change:**
    - Add a post-StagePartitioning step in PhysicalPlanGenerator, which splits ScheduleGroups by Pull StageEdges.
    
    **Minor changes to note:**
    - Renamed CompilerTestUtil.compileMRDAG to CompilerTestUtil.compileWordCountDAG
    - Increased timeout for FaultToleranceTest
    
    **Test for the changes:**
    - Added PhysicalPlanGeneratorTest to test the major change.
    
    **Other comments:**
    N/A
    
    resolves [NEMO-126](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-126)
---
 .../annotating/DefaultScheduleGroupPass.java       |  3 +
 .../runtime/common/plan/PhysicalPlanGenerator.java | 78 ++++++++++++++++++++
 .../edu/snu/nemo/runtime/common/plan/Stage.java    | 23 +++---
 .../snu/nemo/runtime/common/plan/StageEdge.java    | 35 ++++++++-
 .../master/scheduler/FaultToleranceTest.java       |  8 +--
 .../common/plan/PhysicalPlanGeneratorTest.java     | 83 ++++++++++++++++++++++
 .../snu/nemo/tests/compiler/CompilerTestUtil.java  |  2 +-
 .../annotating/DefaultEdgeCoderPassTest.java       |  2 +-
 .../composite/DataSkewCompositePassTest.java       |  2 +-
 9 files changed, 218 insertions(+), 18 deletions(-)

diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java
index 55c16c5..38f964c 100644
--- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java
+++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java
@@ -113,6 +113,9 @@ public final class DefaultScheduleGroupPass extends AnnotatingPass {
         if (skip) {
           continue;
         }
+        if (irVertexToScheduleGroupMap.containsKey(connectedIRVertex)) {
+          continue;
+        }
         // Now we can assure that all vertices that connectedIRVertex depends on have assigned a ScheduleGroup
 
         // Get ScheduleGroup(s) that push data to the connectedIRVertex
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
index d679387..107a623 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
@@ -16,6 +16,7 @@
 package edu.snu.nemo.runtime.common.plan;
 
 import edu.snu.nemo.common.ir.Readable;
+import edu.snu.nemo.common.ir.edge.executionproperty.DataFlowModelProperty;
 import edu.snu.nemo.common.ir.edge.executionproperty.DuplicateEdgeGroupProperty;
 import edu.snu.nemo.common.ir.edge.executionproperty.DuplicateEdgeGroupPropertyValue;
 import edu.snu.nemo.common.ir.executionproperty.ExecutionPropertyMap;
@@ -31,6 +32,7 @@ import edu.snu.nemo.common.ir.edge.IREdge;
 import edu.snu.nemo.common.exception.IllegalVertexOperationException;
 import edu.snu.nemo.common.exception.PhysicalPlanGenerationException;
 import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
+import org.apache.commons.lang3.mutable.MutableInt;
 import org.apache.reef.tang.annotations.Parameter;
 
 import javax.inject.Inject;
@@ -75,6 +77,10 @@ public final class PhysicalPlanGenerator implements Function<DAG<IRVertex, IREdg
     // this is needed because of DuplicateEdgeGroupProperty.
     handleDuplicateEdgeGroupProperty(dagOfStages);
 
+
+    // Split ScheduleGroups by Pull StageEdges
+    splitScheduleGroupByPullStageEdges(dagOfStages);
+
     // for debugging purposes.
     dagOfStages.storeJSON(dagDirectory, "plan-logical", "logical execution plan");
 
@@ -232,4 +238,76 @@ public final class PhysicalPlanGenerator implements Function<DAG<IRVertex, IREdg
       }
     });
   }
+
+  /**
+   * Split ScheduleGroups by Pull {@link StageEdge}s, and ensure topological ordering of
+   * {@link ScheduleGroupIndexProperty}.
+   *
+   * @param dag {@link DAG} of {@link Stage}s to manipulate
+   */
+  private void splitScheduleGroupByPullStageEdges(final DAG<Stage, StageEdge> dag) {
+    final MutableInt nextScheduleGroupIndex = new MutableInt(0);
+    final Map<Stage, Integer> stageToScheduleGroupIndexMap = new HashMap<>();
+
+    dag.topologicalDo(currentStage -> {
+      // Base case: assign a new ScheduleGroupIndex to the Stage if it does not have one yet
+      stageToScheduleGroupIndexMap.computeIfAbsent(currentStage, s -> getAndIncrement(nextScheduleGroupIndex));
+
+      for (final StageEdge stageEdgeFromCurrentStage : dag.getOutgoingEdgesOf(currentStage)) {
+        final Stage destination = stageEdgeFromCurrentStage.getDst();
+        // Skip if any Stage that destination depends on has not yet been assigned a new ScheduleGroupIndex
+        boolean skip = false;
+        for (final StageEdge stageEdgeToDestination : dag.getIncomingEdgesOf(destination)) {
+          if (!stageToScheduleGroupIndexMap.containsKey(stageEdgeToDestination.getSrc())) {
+            skip = true;
+            break;
+          }
+        }
+        if (skip) {
+          continue;
+        }
+        if (stageToScheduleGroupIndexMap.containsKey(destination)) {
+          continue;
+        }
+
+        // Find any non-pull inEdge
+        Integer scheduleGroupIndex = null;
+        Integer newScheduleGroupIndex = null;
+        for (final StageEdge stageEdge : dag.getIncomingEdgesOf(destination)) {
+          final Stage source = stageEdge.getSrc();
+          if (stageEdge.getDataFlowModel() != DataFlowModelProperty.Value.Pull) {
+            if (scheduleGroupIndex != null && source.getScheduleGroupIndex() != scheduleGroupIndex) {
+              throw new RuntimeException(String.format("Multiple Push inEdges from different ScheduleGroup: %d, %d",
+                  scheduleGroupIndex, source.getScheduleGroupIndex()));
+            }
+            if (source.getScheduleGroupIndex() != destination.getScheduleGroupIndex()) {
+              throw new RuntimeException(String.format("Split ScheduleGroup by push StageEdge: %d, %d",
+                  source.getScheduleGroupIndex(), destination.getScheduleGroupIndex()));
+            }
+            scheduleGroupIndex = source.getScheduleGroupIndex();
+            newScheduleGroupIndex = stageToScheduleGroupIndexMap.get(source);
+          }
+        }
+
+        if (newScheduleGroupIndex == null) {
+          stageToScheduleGroupIndexMap.put(destination, getAndIncrement(nextScheduleGroupIndex));
+        } else {
+          stageToScheduleGroupIndexMap.put(destination, newScheduleGroupIndex);
+        }
+      }
+    });
+
+    dag.topologicalDo(stage -> {
+      final int scheduleGroupIndex = stageToScheduleGroupIndexMap.get(stage);
+      stage.getExecutionProperties().put(ScheduleGroupIndexProperty.of(scheduleGroupIndex));
+      stage.getIRDAG().topologicalDo(vertex -> vertex.getExecutionProperties()
+          .put(ScheduleGroupIndexProperty.of(scheduleGroupIndex)));
+    });
+  }
+
+  private static int getAndIncrement(final MutableInt mutableInt) {
+    final int toReturn = mutableInt.getValue();
+    mutableInt.increment();
+    return toReturn;
+  }
 }
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Stage.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Stage.java
index ee0a337..9e5da38 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Stage.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Stage.java
@@ -40,8 +40,6 @@ public final class Stage extends Vertex {
   private final byte[] serializedIRDag;
   private final ExecutionPropertyMap<VertexExecutionProperty> executionProperties;
   private final List<Map<String, Readable>> vertexIdToReadables;
-  private final int parallelism;
-  private final int scheduleGroupIndex;
 
   /**
    * Constructor.
@@ -60,10 +58,6 @@ public final class Stage extends Vertex {
     this.serializedIRDag = SerializationUtils.serialize(irDag);
     this.executionProperties = executionProperties;
     this.vertexIdToReadables = vertexIdToReadables;
-    this.parallelism = executionProperties.get(ParallelismProperty.class)
-        .orElseThrow(() -> new RuntimeException("Parallelism property must be set for Stage"));
-    this.scheduleGroupIndex = executionProperties.get(ScheduleGroupIndexProperty.class)
-        .orElseThrow(() -> new RuntimeException("ScheduleGroupIndex property must be set for Stage"));
   }
 
   /**
@@ -85,17 +79,26 @@ public final class Stage extends Vertex {
    */
   public List<String> getTaskIds() {
     final List<String> taskIds = new ArrayList<>();
-    for (int taskIdx = 0; taskIdx < parallelism; taskIdx++) {
+    for (int taskIdx = 0; taskIdx < getParallelism(); taskIdx++) {
       taskIds.add(RuntimeIdGenerator.generateTaskId(taskIdx, getId()));
     }
     return taskIds;
   }
 
   /**
+   * @return the parallelism
+   */
+  public int getParallelism() {
+    return executionProperties.get(ParallelismProperty.class)
+        .orElseThrow(() -> new RuntimeException("Parallelism property must be set for Stage"));
+  }
+
+  /**
    * @return the schedule group index.
    */
   public int getScheduleGroupIndex() {
-    return scheduleGroupIndex;
+    return executionProperties.get(ScheduleGroupIndexProperty.class)
+        .orElseThrow(() -> new RuntimeException("ScheduleGroupIndex property must be set for Stage"));
   }
 
   /**
@@ -127,9 +130,9 @@ public final class Stage extends Vertex {
   @Override
   public String propertiesToJSON() {
     final StringBuilder sb = new StringBuilder();
-    sb.append("{\"scheduleGroupIndex\": ").append(scheduleGroupIndex);
+    sb.append("{\"scheduleGroupIndex\": ").append(getScheduleGroupIndex());
     sb.append(", \"irDag\": ").append(irDag);
-    sb.append(", \"parallelism\": ").append(parallelism);
+    sb.append(", \"parallelism\": ").append(getParallelism());
     sb.append(", \"executionProperties\": ").append(executionProperties);
     sb.append('}');
     return sb.toString();
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java
index d2697f4..a45b9b6 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java
@@ -16,6 +16,9 @@
 package edu.snu.nemo.runtime.common.plan;
 
 import com.google.common.annotations.VisibleForTesting;
+import edu.snu.nemo.common.ir.edge.executionproperty.DataCommunicationPatternProperty;
+import edu.snu.nemo.common.ir.edge.executionproperty.DataFlowModelProperty;
+import edu.snu.nemo.common.ir.executionproperty.EdgeExecutionProperty;
 import edu.snu.nemo.common.ir.vertex.IRVertex;
 import edu.snu.nemo.common.ir.executionproperty.ExecutionPropertyMap;
 import edu.snu.nemo.runtime.common.data.KeyRange;
@@ -47,6 +50,16 @@ public final class StageEdge extends RuntimeEdge<Stage> {
   private List<KeyRange> taskIdxToKeyRange;
 
   /**
+   * Value for {@link DataCommunicationPatternProperty}.
+   */
+  private final DataCommunicationPatternProperty.Value dataCommunicationPatternValue;
+
+  /**
+   * Value for {@link DataFlowModelProperty}.
+   */
+  private final DataFlowModelProperty.Value dataFlowModelValue;
+
+  /**
    * Constructor.
    *
    * @param runtimeEdgeId  id of the runtime edge.
@@ -59,7 +72,7 @@ public final class StageEdge extends RuntimeEdge<Stage> {
    */
   @VisibleForTesting
   public StageEdge(final String runtimeEdgeId,
-            final ExecutionPropertyMap edgeProperties,
+            final ExecutionPropertyMap<EdgeExecutionProperty> edgeProperties,
             final IRVertex srcVertex,
             final IRVertex dstVertex,
             final Stage srcStage,
@@ -73,6 +86,12 @@ public final class StageEdge extends RuntimeEdge<Stage> {
     for (int taskIdx = 0; taskIdx < dstStage.getTaskIds().size(); taskIdx++) {
       taskIdxToKeyRange.add(HashRange.of(taskIdx, taskIdx + 1));
     }
+    this.dataCommunicationPatternValue = edgeProperties.get(DataCommunicationPatternProperty.class)
+        .orElseThrow(() -> new RuntimeException(String.format(
+            "DataCommunicationPatternProperty not set for %s", runtimeEdgeId)));
+    this.dataFlowModelValue = edgeProperties.get(DataFlowModelProperty.class)
+        .orElseThrow(() -> new RuntimeException(String.format(
+            "DataFlowModelProperty not set for %s", runtimeEdgeId)));
   }
 
   /**
@@ -115,4 +134,18 @@ public final class StageEdge extends RuntimeEdge<Stage> {
   public void setTaskIdxToKeyRange(final List<KeyRange> taskIdxToKeyRange) {
     this.taskIdxToKeyRange = taskIdxToKeyRange;
   }
+
+  /**
+   * @return {@link DataCommunicationPatternProperty} value.
+   */
+  public DataCommunicationPatternProperty.Value getDataCommunicationPattern() {
+    return dataCommunicationPatternValue;
+  }
+
+  /**
+   * @return {@link DataFlowModelProperty} value.
+   */
+  public DataFlowModelProperty.Value getDataFlowModel() {
+    return dataFlowModelValue;
+  }
 }
diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FaultToleranceTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FaultToleranceTest.java
index 0e57dc5..62b5181 100644
--- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FaultToleranceTest.java
+++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FaultToleranceTest.java
@@ -102,7 +102,7 @@ public final class FaultToleranceTest {
   /**
    * Tests fault tolerance after a container removal.
    */
-  @Test(timeout=5000)
+  @Test(timeout=50000)
   public void testContainerRemoval() throws Exception {
     final ActiveContext activeContext = mock(ActiveContext.class);
     Mockito.doThrow(new RuntimeException()).when(activeContext).close();
@@ -180,7 +180,7 @@ public final class FaultToleranceTest {
   /**
    * Tests fault tolerance after an output write failure.
    */
-  @Test(timeout=5000)
+  @Test(timeout=50000)
   public void testOutputFailure() throws Exception {
     final ActiveContext activeContext = mock(ActiveContext.class);
     Mockito.doThrow(new RuntimeException()).when(activeContext).close();
@@ -245,7 +245,7 @@ public final class FaultToleranceTest {
   /**
    * Tests fault tolerance after an input read failure.
    */
-  @Test(timeout=5000)
+  @Test(timeout=50000)
   public void testInputReadFailure() throws Exception {
     final ActiveContext activeContext = mock(ActiveContext.class);
     Mockito.doThrow(new RuntimeException()).when(activeContext).close();
@@ -311,7 +311,7 @@ public final class FaultToleranceTest {
   /**
    * Tests the rescheduling of Tasks upon a failure.
    */
-  @Test(timeout=20000)
+  @Test(timeout=200000)
   public void testTaskReexecutionForFailure() throws Exception {
     final ActiveContext activeContext = mock(ActiveContext.class);
     Mockito.doThrow(new RuntimeException()).when(activeContext).close();
diff --git a/tests/src/test/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGeneratorTest.java b/tests/src/test/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGeneratorTest.java
new file mode 100644
index 0000000..f8c4f90
--- /dev/null
+++ b/tests/src/test/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGeneratorTest.java
@@ -0,0 +1,83 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *         http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.runtime.common.plan;
+
+import edu.snu.nemo.common.dag.DAG;
+import edu.snu.nemo.common.dag.DAGBuilder;
+import edu.snu.nemo.common.ir.edge.IREdge;
+import edu.snu.nemo.common.ir.edge.executionproperty.DataCommunicationPatternProperty;
+import edu.snu.nemo.common.ir.edge.executionproperty.DataFlowModelProperty;
+import edu.snu.nemo.common.ir.vertex.IRVertex;
+import edu.snu.nemo.common.ir.vertex.OperatorVertex;
+import edu.snu.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
+import edu.snu.nemo.common.ir.vertex.executionproperty.ScheduleGroupIndexProperty;
+import edu.snu.nemo.common.ir.vertex.transform.Transform;
+import edu.snu.nemo.common.test.EmptyComponents;
+import org.apache.reef.tang.Injector;
+import org.apache.reef.tang.Tang;
+import org.junit.Test;
+
+import java.util.Iterator;
+
+import static org.junit.Assert.assertNotEquals;
+
+/**
+ * Tests {@link PhysicalPlanGenerator}.
+ */
+public final class PhysicalPlanGeneratorTest {
+  private static final Transform EMPTY_TRANSFORM = new EmptyComponents.EmptyTransform("");
+
+  /**
+   * Test splitting ScheduleGroups by Pull StageEdges.
+   * @throws Exception exceptions on the way
+   */
+  @Test
+  public void testSplitScheduleGroupByPullStageEdges() throws Exception {
+    final Injector injector = Tang.Factory.getTang().newInjector();
+    final PhysicalPlanGenerator physicalPlanGenerator = injector.getInstance(PhysicalPlanGenerator.class);
+
+    final IRVertex v0 = newIRVertex(0, 5);
+    final IRVertex v1 = newIRVertex(0, 3);
+    final DAG<IRVertex, IREdge> irDAG = new DAGBuilder<IRVertex, IREdge>()
+        .addVertex(v0)
+        .addVertex(v1)
+        .connectVertices(newIREdge(v0, v1, DataCommunicationPatternProperty.Value.OneToOne,
+            DataFlowModelProperty.Value.Pull))
+        .buildWithoutSourceSinkCheck();
+
+    final DAG<Stage, StageEdge> stageDAG = physicalPlanGenerator.apply(irDAG);
+    final Iterator<Stage> stages = stageDAG.getVertices().iterator();
+    final Stage s0 = stages.next();
+    final Stage s1 = stages.next();
+
+    assertNotEquals(s0.getScheduleGroupIndex(), s1.getScheduleGroupIndex());
+  }
+
+  private static final IRVertex newIRVertex(final int scheduleGroupIndex, final int parallelism) {
+    final IRVertex irVertex = new OperatorVertex(EMPTY_TRANSFORM);
+    irVertex.getExecutionProperties().put(ScheduleGroupIndexProperty.of(scheduleGroupIndex));
+    irVertex.getExecutionProperties().put(ParallelismProperty.of(parallelism));
+    return irVertex;
+  }
+
+  private static final IREdge newIREdge(final IRVertex src, final IRVertex dst,
+                                        final DataCommunicationPatternProperty.Value communicationPattern,
+                                        final DataFlowModelProperty.Value dataFlowModel) {
+    final IREdge irEdge = new IREdge(communicationPattern, src, dst);
+    irEdge.getExecutionProperties().put(DataFlowModelProperty.of(dataFlowModel));
+    return irEdge;
+  }
+}
diff --git a/tests/src/test/java/edu/snu/nemo/tests/compiler/CompilerTestUtil.java b/tests/src/test/java/edu/snu/nemo/tests/compiler/CompilerTestUtil.java
index 63c02d9..d5f17f4 100644
--- a/tests/src/test/java/edu/snu/nemo/tests/compiler/CompilerTestUtil.java
+++ b/tests/src/test/java/edu/snu/nemo/tests/compiler/CompilerTestUtil.java
@@ -59,7 +59,7 @@ public final class CompilerTestUtil {
     return captor.getValue();
   }
 
-  public static DAG<IRVertex, IREdge> compileMRDAG() throws Exception {
+  public static DAG<IRVertex, IREdge> compileWordCountDAG() throws Exception {
     final String input = rootDir + "/../examples/resources/sample_input_wordcount";
     final String output = rootDir + "/../examples/resources/sample_output";
     final String main = "edu.snu.nemo.examples.beam.WordCount";
diff --git a/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/annotating/DefaultEdgeCoderPassTest.java b/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/annotating/DefaultEdgeCoderPassTest.java
index 002873c..7f9dc3a 100644
--- a/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/annotating/DefaultEdgeCoderPassTest.java
+++ b/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/annotating/DefaultEdgeCoderPassTest.java
@@ -45,7 +45,7 @@ public class DefaultEdgeCoderPassTest {
 
   @Before
   public void setUp() throws Exception {
-    compiledDAG = CompilerTestUtil.compileMRDAG();
+    compiledDAG = CompilerTestUtil.compileWordCountDAG();
   }
 
   @Test
diff --git a/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/composite/DataSkewCompositePassTest.java b/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/composite/DataSkewCompositePassTest.java
index 50b3cdb..7f81154 100644
--- a/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/composite/DataSkewCompositePassTest.java
+++ b/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/composite/DataSkewCompositePassTest.java
@@ -81,7 +81,7 @@ public class DataSkewCompositePassTest {
    */
   @Test
   public void testDataSkewPass() throws Exception {
-    mrDAG = CompilerTestUtil.compileMRDAG();
+    mrDAG = CompilerTestUtil.compileWordCountDAG();
     final Integer originalVerticesNum = mrDAG.getVertices().size();
     final Long numOfShuffleGatherEdges = mrDAG.getVertices().stream().filter(irVertex ->
         mrDAG.getIncomingEdgesOf(irVertex).stream().anyMatch(irEdge ->