You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nemo.apache.org by sa...@apache.org on 2019/02/21 07:38:58 UTC

[incubator-nemo] branch master updated: [NEMO-338] SkewSamplingPass (#193)

This is an automated email from the ASF dual-hosted git repository.

sanha pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-nemo.git


The following commit(s) were added to refs/heads/master by this push:
     new a2b02dc  [NEMO-338] SkewSamplingPass (#193)
a2b02dc is described below

commit a2b02dc85b433c68b8feafe63248e05c0724b4f8
Author: John Yang <jo...@gmail.com>
AuthorDate: Thu Feb 21 16:38:54 2019 +0900

    [NEMO-338] SkewSamplingPass (#193)
    
    JIRA: [NEMO-338: SkewSamplingPass](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-338)
    
    **Major changes:**
    - SamplingSkewReshapingPass: Inserts SkewSampling and MessageBarrier vertices
    - SamplingVertex: Instantiated with (originalVertex, desiredSampleRate)
    - IRDAG: Automatically inserts IREdges from/to SamplingVertex objects, similar to other insert() methods
    - PhysicalPlanGenerator: Handles SamplingVertex objects appropriately
    - Stage: Uses getTaskIndices(), which returns a subset of tasks if the stage consists of SamplingVertex objects, to determine the tasks to execute
    
    **Minor changes to note:**
    - Refactors other insert() methods to share code as much as possible
    
    **Tests for the changes:**
    - PerKeyMedianITCase#testLargeShuffleSamplingSkew (combines large shuffle + skew handling optimizations)
    
    **Other comments:**
    - Sanha(@sanha) wrote the original code. I refactored the code and added comments to create this PR.
    
    Closes #193
---
 .../java/org/apache/nemo/client/JobLauncher.java   |   2 +-
 .../src/main/java/org/apache/nemo/common/Util.java | 104 +++++++-
 .../main/java/org/apache/nemo/common/dag/DAG.java  |   5 +
 .../org/apache/nemo/common/dag/DAGBuilder.java     |  27 +-
 .../org/apache/nemo/common/dag/DAGInterface.java   |  18 +-
 .../main/java/org/apache/nemo/common/ir/IRDAG.java | 295 ++++++++++++++++-----
 .../org/apache/nemo/common/ir/edge/IREdge.java     |  15 +-
 ...eIdProperty.java => MessageIdEdgeProperty.java} |  10 +-
 .../apache/nemo/common/ir/vertex/LoopVertex.java   |   2 +-
 .../apache/nemo/common/ir/vertex/SourceVertex.java |   2 +-
 .../MessageIdVertexProperty.java}                  |  14 +-
 .../ir/vertex/utility/MessageAggregatorVertex.java |  11 +-
 .../common/ir/vertex/utility/SamplingVertex.java   | 119 +++++++++
 .../java/org/apache/nemo/common/util/UtilTest.java |   1 +
 .../compiler/backend/nemo/NemoPlanRewriter.java    |   6 +-
 .../annotating/LargeShuffleAnnotatingPass.java     |  18 +-
 .../reshaping/SamplingSkewReshapingPass.java       | 137 ++++++++++
 .../compiletime/reshaping/SkewHandlingUtil.java    |  77 ++++++
 .../compiletime/reshaping/SkewReshapingPass.java   |  49 +---
 .../policy/SamplingLargeShuffleSkewPolicy.java     |  58 ++++
 .../compiler/backend/nemo/DAGConverterTest.java    |   4 +-
 .../nemo/examples/beam/PerKeyMedianITCase.java     |  13 +
 .../beam_test_one_executor_resources.json          |   2 +-
 .../runtime/common/plan/PhysicalPlanGenerator.java |  80 +++++-
 .../org/apache/nemo/runtime/common/plan/Stage.java |  18 +-
 .../apache/nemo/runtime/common/plan/StageEdge.java |   8 +-
 .../common/plan/PhysicalPlanGeneratorTest.java     |   5 +-
 .../org/apache/nemo/runtime/executor/TestUtil.java |   4 +-
 .../executor/datatransfer/DataTransferTest.java    |   7 +-
 .../nemo/runtime/master/PlanStateManager.java      |  46 ++--
 .../runtime/master/scheduler/BatchScheduler.java   |  32 +--
 31 files changed, 950 insertions(+), 239 deletions(-)

diff --git a/client/src/main/java/org/apache/nemo/client/JobLauncher.java b/client/src/main/java/org/apache/nemo/client/JobLauncher.java
index 2a89a5e..ba6194b 100644
--- a/client/src/main/java/org/apache/nemo/client/JobLauncher.java
+++ b/client/src/main/java/org/apache/nemo/client/JobLauncher.java
@@ -181,7 +181,7 @@ public final class JobLauncher {
           LOG.info("Wait for the driver to finish");
           driverLauncher.wait();
         } catch (final InterruptedException e) {
-          LOG.warn("Interrupted: " + e);
+          LOG.warn("Interrupted: ", e);
           // clean up state...
           Thread.currentThread().interrupt();
         }
diff --git a/common/src/main/java/org/apache/nemo/common/Util.java b/common/src/main/java/org/apache/nemo/common/Util.java
index 288c5c2..f888ceb 100644
--- a/common/src/main/java/org/apache/nemo/common/Util.java
+++ b/common/src/main/java/org/apache/nemo/common/Util.java
@@ -16,14 +16,22 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.nemo.common.util;
+package org.apache.nemo.common;
 
+import org.apache.nemo.common.ir.edge.IREdge;
+import org.apache.nemo.common.ir.edge.executionproperty.*;
+import org.apache.nemo.common.ir.vertex.IRVertex;
+
+import java.util.Collection;
 import java.util.function.IntPredicate;
+import java.util.stream.Collectors;
 
 /**
  * Class to hold the utility methods.
  */
 public final class Util {
+  // Assume that this tag is never used in user application
+  public static final String CONTROL_EDGE_TAG = "CONTROL_EDGE";
 
   /**
    * Private constructor for utility class.
@@ -42,7 +50,7 @@ public final class Util {
    * @return whether or not we can say that they are equal.
    */
   public static boolean checkEqualityOfIntPredicates(final IntPredicate firstPredicate,
-      final IntPredicate secondPredicate, final int noOfTimes) {
+                                                     final IntPredicate secondPredicate, final int noOfTimes) {
     for (int value = 0; value <= noOfTimes; value++) {
       if (firstPredicate.test(value) != secondPredicate.test(value)) {
         return false;
@@ -51,4 +59,96 @@ public final class Util {
     return true;
   }
 
+  /**
+   * @param edgeToClone to copy execution properties from.
+   * @param newSrc of the new edge.
+   * @param newDst of the new edge.
+   * @return the new edge.
+   */
+  public static IREdge cloneEdge(final IREdge edgeToClone,
+                                 final IRVertex newSrc,
+                                 final IRVertex newDst) {
+    return cloneEdge(
+      edgeToClone.getPropertyValue(CommunicationPatternProperty.class).get(), edgeToClone, newSrc, newDst);
+  }
+
+  /**
+   * Creates a new edge that copies a subset of the execution properties of the given edge.
+   * The copied execution properties include those minimally required for execution, such as encoder/decoders.
+   *
+   * @param commPattern to use.
+   * @param edgeToClone to copy execution properties from.
+   * @param newSrc of the new edge.
+   * @param newDst of the new edge.
+   * @return the new edge.
+   */
+  public static IREdge cloneEdge(final CommunicationPatternProperty.Value commPattern,
+                                 final IREdge edgeToClone,
+                                 final IRVertex newSrc,
+                                 final IRVertex newDst) {
+    final IREdge clone = new IREdge(commPattern, newSrc, newDst);
+
+    if (edgeToClone.getPropertySnapshot().containsKey(EncoderProperty.class)) {
+      clone.setProperty(edgeToClone.getPropertySnapshot().get(EncoderProperty.class));
+    } else {
+      clone.setProperty(EncoderProperty.of(edgeToClone.getPropertyValue(EncoderProperty.class)
+        .orElseThrow(IllegalStateException::new)));
+    }
+
+    if (edgeToClone.getPropertySnapshot().containsKey(DecoderProperty.class)) {
+      clone.setProperty(edgeToClone.getPropertySnapshot().get(DecoderProperty.class));
+    } else {
+      clone.setProperty(DecoderProperty.of(edgeToClone.getPropertyValue(DecoderProperty.class)
+        .orElseThrow(IllegalStateException::new)));
+    }
+
+    edgeToClone.getPropertyValue(AdditionalOutputTagProperty.class).ifPresent(tag -> {
+      clone.setProperty(AdditionalOutputTagProperty.of(tag));
+    });
+
+    edgeToClone.getPropertyValue(PartitionerProperty.class).ifPresent(p -> {
+      if (p.right() == PartitionerProperty.NUM_EQUAL_TO_DST_PARALLELISM) {
+        clone.setProperty(PartitionerProperty.of(p.left()));
+      } else {
+        clone.setProperty(PartitionerProperty.of(p.left(), p.right()));
+      }
+    });
+
+    edgeToClone.getPropertyValue(KeyExtractorProperty.class).ifPresent(ke -> {
+      clone.setProperty(KeyExtractorProperty.of(ke));
+    });
+
+    return clone;
+  }
+
+  /**
+   * A control edge enforces an execution ordering between the source vertex and the destination vertex.
+   * The additional output tag property of control edges is set such that no actual data element is transferred
+   * via the edges. This minimizes the run-time overhead of executing control edges.
+   *
+   * @param src vertex.
+   * @param dst vertex.
+   * @return the control edge.
+   */
+  public static IREdge createControlEdge(final IRVertex src, final IRVertex dst) {
+    final IREdge controlEdge = new IREdge(CommunicationPatternProperty.Value.BroadCast, src, dst);
+    controlEdge.setPropertyPermanently(AdditionalOutputTagProperty.of(CONTROL_EDGE_TAG));
+    return controlEdge;
+  }
+
+  /**
+   * @param vertices to stringify ids of.
+   * @return the string of ids.
+   */
+  public static String stringifyIRVertexIds(final Collection<IRVertex> vertices) {
+    return vertices.stream().map(IRVertex::getId).collect(Collectors.toSet()).toString();
+  }
+
+  /**
+   * @param edges to stringify ids of.
+   * @return the string of ids.
+   */
+  public static String stringifyIREdgeIds(final Collection<IREdge> edges) {
+    return edges.stream().map(IREdge::getId).collect(Collectors.toSet()).toString();
+  }
 }
diff --git a/common/src/main/java/org/apache/nemo/common/dag/DAG.java b/common/src/main/java/org/apache/nemo/common/dag/DAG.java
index e6a55b7..de11d98 100644
--- a/common/src/main/java/org/apache/nemo/common/dag/DAG.java
+++ b/common/src/main/java/org/apache/nemo/common/dag/DAG.java
@@ -103,6 +103,11 @@ public final class DAG<V extends Vertex, E extends Edge<V>> implements DAGInterf
   }
 
   @Override
+  public List<E> getEdges() {
+    return incomingEdges.values().stream().flatMap(List::stream).collect(Collectors.toList());
+  }
+
+  @Override
   public List<V> getRootVertices() {
     return rootVertices;
   }
diff --git a/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java b/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java
index 1fe2b13..bd5cc95 100644
--- a/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java
+++ b/common/src/main/java/org/apache/nemo/common/dag/DAGBuilder.java
@@ -19,12 +19,11 @@
 package org.apache.nemo.common.dag;
 
 import org.apache.nemo.common.exception.CompileTimeOptimizationException;
-import org.apache.nemo.common.ir.edge.IREdge;
-import org.apache.nemo.common.ir.edge.executionproperty.DataFlowProperty;
-import org.apache.nemo.common.ir.edge.executionproperty.MessageIdProperty;
 import org.apache.nemo.common.ir.vertex.*;
 import org.apache.nemo.common.exception.IllegalVertexOperationException;
+import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty;
 import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
+import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
 
 import java.io.Serializable;
 import java.util.*;
@@ -227,7 +226,8 @@ public final class DAGBuilder<V extends Vertex, E extends Edge<V>> implements Se
     final Supplier<Stream<V>> verticesToObserve = () -> vertices.stream().filter(v -> incomingEdges.get(v).isEmpty())
         .filter(v -> v instanceof IRVertex);
     // They should all match SourceVertex
-    if (verticesToObserve.get().anyMatch(v -> !(v instanceof SourceVertex))) {
+    if (!(verticesToObserve.get().allMatch(v -> (v instanceof SourceVertex)
+      || (v instanceof SamplingVertex && ((SamplingVertex) v).getCloneOfOriginalVertex() instanceof SourceVertex)))) {
       final String problematicVertices = verticesToObserve.get()
           .filter(v -> !(v instanceof SourceVertex))
           .map(V::getId)
@@ -258,16 +258,15 @@ public final class DAGBuilder<V extends Vertex, E extends Edge<V>> implements Se
    * Helper method to check that all execution properties are correct and makes sense.
    */
   private void executionPropertyCheck() {
-    // DataSizeMetricCollection is not compatible with Push (All data have to be stored before the data collection)
-    vertices.forEach(v -> incomingEdges.get(v).stream().filter(e -> e instanceof IREdge).map(e -> (IREdge) e)
-        .filter(e -> e.getPropertyValue(MessageIdProperty.class).isPresent())
-        .filter(e -> !(e.getDst() instanceof OperatorVertex
-          && e.getDst() instanceof MessageAggregatorVertex))
-        .filter(e -> DataFlowProperty.Value.Push.equals(e.getPropertyValue(DataFlowProperty.class).get()))
-        .forEach(e -> {
-          throw new CompileTimeOptimizationException("DAG execution property check: "
-              + "DataSizeMetricCollection edge is not compatible with push" + e.getId());
-        }));
+    final long numOfMAV = vertices.stream().filter(v -> v instanceof MessageAggregatorVertex).count();
+    final long numOfDistinctMessageIds = vertices.stream()
+      .filter(v -> v instanceof MessageAggregatorVertex)
+      .map(v -> ((MessageAggregatorVertex) v).getPropertyValue(MessageIdVertexProperty.class).get())
+      .distinct()
+      .count();
+    if (numOfMAV != numOfDistinctMessageIds) {
+      throw getException("A unique message id must exist for each MessageAggregator", "");
+    }
   }
 
   /**
diff --git a/common/src/main/java/org/apache/nemo/common/dag/DAGInterface.java b/common/src/main/java/org/apache/nemo/common/dag/DAGInterface.java
index 40cb72c..eba2e4a 100644
--- a/common/src/main/java/org/apache/nemo/common/dag/DAGInterface.java
+++ b/common/src/main/java/org/apache/nemo/common/dag/DAGInterface.java
@@ -44,21 +44,27 @@ public interface DAGInterface<V extends Vertex, E extends Edge<V>> extends Seria
 
   /**
    * Retrieves the vertices of this DAG.
-   * @return the set of vertices.
+   * @return the list of vertices.
    * Note that the result is never null, ensured by {@link DAGBuilder}.
    */
   List<V> getVertices();
 
   /**
+   * Retrieves the edges of this DAG.
+   * @return the list of edges.
+   */
+  List<E> getEdges();
+
+  /**
    * Retrieves the root vertices of this DAG.
-   * @return the set of root vertices.
+   * @return the list of root vertices.
    */
   List<V> getRootVertices();
 
   /**
    * Retrieves the incoming edges of the given vertex.
    * @param v the subject vertex.
-   * @return the set of incoming edges to the vertex.
+   * @return the list of incoming edges to the vertex.
    * Note that the result is never null, ensured by {@link DAGBuilder}.
    */
   List<E> getIncomingEdgesOf(final V v);
@@ -66,7 +72,7 @@ public interface DAGInterface<V extends Vertex, E extends Edge<V>> extends Seria
   /**
    * Retrieves the incoming edges of the given vertex.
    * @param vertexId the ID of the subject vertex.
-   * @return the set of incoming edges to the vertex.
+   * @return the list of incoming edges to the vertex.
    * Note that the result is never null, ensured by {@link DAGBuilder}.
    */
   List<E> getIncomingEdgesOf(final String vertexId);
@@ -74,7 +80,7 @@ public interface DAGInterface<V extends Vertex, E extends Edge<V>> extends Seria
   /**
    * Retrieves the outgoing edges of the given vertex.
    * @param v the subject vertex.
-   * @return the set of outgoing edges to the vertex.
+   * @return the list of outgoing edges to the vertex.
    * Note that the result is never null, ensured by {@link DAGBuilder}.
    */
   List<E> getOutgoingEdgesOf(final V v);
@@ -82,7 +88,7 @@ public interface DAGInterface<V extends Vertex, E extends Edge<V>> extends Seria
   /**
    * Retrieves the outgoing edges of the given vertex.
    * @param vertexId the ID of the subject vertex.
-   * @return the set of outgoing edges to the vertex.
+   * @return the list of outgoing edges to the vertex.
    * Note that the result is never null, ensured by {@link DAGBuilder}.
    */
   List<E> getOutgoingEdgesOf(final String vertexId);
diff --git a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
index a6cd1c7..af79e72 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
@@ -19,25 +19,31 @@
 package org.apache.nemo.common.ir;
 
 import com.fasterxml.jackson.databind.node.ObjectNode;
-import org.apache.nemo.common.PairKeyExtractor;
+import com.google.common.collect.Sets;
+import org.apache.nemo.common.KeyExtractor;
+import org.apache.nemo.common.Pair;
+import org.apache.nemo.common.Util;
+import org.apache.nemo.common.coder.BytesDecoderFactory;
+import org.apache.nemo.common.coder.BytesEncoderFactory;
 import org.apache.nemo.common.dag.DAG;
 import org.apache.nemo.common.dag.DAGBuilder;
 import org.apache.nemo.common.dag.DAGInterface;
+import org.apache.nemo.common.exception.CompileTimeOptimizationException;
 import org.apache.nemo.common.exception.IllegalEdgeOperationException;
 import org.apache.nemo.common.ir.edge.IREdge;
 import org.apache.nemo.common.ir.edge.executionproperty.*;
 import org.apache.nemo.common.ir.vertex.IRVertex;
 import org.apache.nemo.common.ir.vertex.LoopVertex;
+import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty;
 import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
 import org.apache.nemo.common.ir.vertex.utility.MessageBarrierVertex;
+import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
 import org.apache.nemo.common.ir.vertex.utility.StreamVertex;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import javax.annotation.concurrent.NotThreadSafe;
-import java.util.List;
-import java.util.Set;
-import java.util.concurrent.atomic.AtomicInteger;
+import java.util.*;
 import java.util.function.BiFunction;
 import java.util.function.Consumer;
 import java.util.function.Function;
@@ -53,13 +59,13 @@ import java.util.stream.Collectors;
  * All of these methods preserve application semantics.
  * - Annotation: setProperty(), getPropertyValue() on each IRVertex/IREdge
  * - Reshaping: insert(), delete() on the IRDAG
+ *
+ * TODO #341: Rethink IRDAG insert() signatures
  */
 @NotThreadSafe
 public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
   private static final Logger LOG = LoggerFactory.getLogger(IRDAG.class.getName());
 
-  private final AtomicInteger metricCollectionId;
-
   private DAG<IRVertex, IREdge> dagSnapshot; // the DAG that was saved most recently.
   private DAG<IRVertex, IREdge> modifiedDAG; // the DAG that is being updated.
 
@@ -69,7 +75,6 @@ public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
   public IRDAG(final DAG<IRVertex, IREdge> originalUserApplicationDAG) {
     this.modifiedDAG = originalUserApplicationDAG;
     this.dagSnapshot = originalUserApplicationDAG;
-    this.metricCollectionId = new AtomicInteger(0);
   }
 
   //////////////////////////////////////////////////
@@ -103,15 +108,25 @@ public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
    * After: src - edgeToStreamizeWithNewDestination - streamVertex - oneToOneEdge - dst
    * (replaces the "Before" relationships)
    *
+   * This preserves semantics as the streamVertex simply forwards data elements from the input edge to the output edge.
+   *
    * @param streamVertex to insert.
    * @param edgeToStreamize to modify.
    */
   public void insert(final StreamVertex streamVertex, final IREdge edgeToStreamize) {
+    assertNonExistence(streamVertex);
+
     // Create a completely new DAG with the vertex inserted.
-    final DAGBuilder builder = new DAGBuilder();
+    final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
+
+    // Integrity check
+    if (edgeToStreamize.getPropertyValue(MessageIdEdgeProperty.class).isPresent()) {
+      throw new CompileTimeOptimizationException(edgeToStreamize.getId() + " has a MessageId, and cannot be removed");
+    }
 
     // Insert the vertex.
-    builder.addVertex(streamVertex);
+    final IRVertex vertexToInsert = wrapSamplingVertexIfNeeded(streamVertex, edgeToStreamize.getSrc());
+    builder.addVertex(vertexToInsert);
 
     // Build the new DAG to reflect the new topology.
     modifiedDAG.topologicalDo(v -> {
@@ -122,20 +137,35 @@ public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
           // MATCH!
 
           // Edge to the streamVertex
-          final IREdge edgeToStreamizeWithNewDestination = new IREdge(
+          final IREdge toSV = new IREdge(
             edgeToStreamize.getPropertyValue(CommunicationPatternProperty.class).get(),
             edgeToStreamize.getSrc(),
-            streamVertex);
-          edgeToStreamize.copyExecutionPropertiesTo(edgeToStreamizeWithNewDestination);
+            vertexToInsert);
+          edgeToStreamize.copyExecutionPropertiesTo(toSV);
 
           // Edge from the streamVertex.
-          final IREdge oneToOneEdge = new IREdge(CommunicationPatternProperty.Value.OneToOne, streamVertex, v);
-          oneToOneEdge.setProperty(EncoderProperty.of(edgeToStreamize.getPropertyValue(EncoderProperty.class).get()));
-          oneToOneEdge.setProperty(DecoderProperty.of(edgeToStreamize.getPropertyValue(DecoderProperty.class).get()));
+          final IREdge fromSV = new IREdge(CommunicationPatternProperty.Value.OneToOne, vertexToInsert, v);
+          fromSV.setProperty(EncoderProperty.of(edgeToStreamize.getPropertyValue(EncoderProperty.class).get()));
+          fromSV.setProperty(DecoderProperty.of(edgeToStreamize.getPropertyValue(DecoderProperty.class).get()));
+
+          // Future optimizations may want to use the original encoders/compressions.
+          toSV.setPropertySnapshot();
+          fromSV.setPropertySnapshot();
+
+          // Annotations for efficient data transfers - toSV
+          toSV.setPropertyPermanently(DecoderProperty.of(BytesDecoderFactory.of()));
+          toSV.setPropertyPermanently(CompressionProperty.of(CompressionProperty.Value.LZ4));
+          toSV.setPropertyPermanently(DecompressionProperty.of(CompressionProperty.Value.None));
+
+          // Annotations for efficient data transfers - fromSV
+          fromSV.setPropertyPermanently(EncoderProperty.of(BytesEncoderFactory.of()));
+          fromSV.setPropertyPermanently(CompressionProperty.of(CompressionProperty.Value.None));
+          fromSV.setPropertyPermanently(DecompressionProperty.of(CompressionProperty.Value.LZ4));
+          fromSV.setPropertyPermanently(PartitionerProperty.of(PartitionerProperty.Type.DedicatedKeyPerElement));
 
           // Track the new edges.
-          builder.connectVertices(edgeToStreamizeWithNewDestination);
-          builder.connectVertices(oneToOneEdge);
+          builder.connectVertices(toSV);
+          builder.connectVertices(fromSV);
         } else {
           // NO MATCH, so simply connect vertices as before.
           builder.connectVertices(edge);
@@ -156,77 +186,174 @@ public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
    *        shuffleEdge - messageAggregatorVertex - broadcastEdge - dst
    * (the "Before" relationships are unmodified)
    *
+   * This preserves semantics as the results of the inserted message vertices are never consumed by the original IRDAG.
+   *
    * @param messageBarrierVertex to insert.
    * @param messageAggregatorVertex to insert.
    * @param mbvOutputEncoder to use.
    * @param mbvOutputDecoder to use.
    * @param edgesToGetStatisticsOf to examine.
+   * @param edgesToOptimize to optimize.
    */
   public void insert(final MessageBarrierVertex messageBarrierVertex,
                      final MessageAggregatorVertex messageAggregatorVertex,
                      final EncoderProperty mbvOutputEncoder,
                      final DecoderProperty mbvOutputDecoder,
-                     final Set<IREdge> edgesToGetStatisticsOf) {
+                     final Set<IREdge> edgesToGetStatisticsOf,
+                     final Set<IREdge> edgesToOptimize) {
+    assertNonExistence(messageBarrierVertex);
+    assertNonExistence(messageAggregatorVertex);
+
     if (edgesToGetStatisticsOf.stream().map(edge -> edge.getDst().getId()).collect(Collectors.toSet()).size() != 1) {
-      throw new IllegalArgumentException("Not destined to the same vertex: " + edgesToGetStatisticsOf.toString());
+      throw new IllegalArgumentException("Not destined to the same vertex: " + edgesToOptimize.toString());
+    }
+    if (edgesToOptimize.stream().map(edge -> edge.getDst().getId()).collect(Collectors.toSet()).size() != 1) {
+      throw new IllegalArgumentException("Not destined to the same vertex: " + edgesToOptimize.toString());
     }
-    final IRVertex dst = edgesToGetStatisticsOf.iterator().next().getDst();
 
     // Create a completely new DAG with the vertex inserted.
-    final DAGBuilder builder = new DAGBuilder();
+    final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
 
-    // Current metric collection id.
-    final int currentMetricCollectionId = metricCollectionId.incrementAndGet();
+    // All of the existing vertices and edges remain intact
+    modifiedDAG.topologicalDo(v -> {
+      builder.addVertex(v);
+      modifiedDAG.getIncomingEdgesOf(v).forEach(builder::connectVertices);
+    });
 
-    // First, add all the vertices.
-    modifiedDAG.topologicalDo(v -> builder.addVertex(v));
+    ////////////////////////////////// STEP 1: Insert new vertices and edges (src - mbv - mav - dst)
 
-    // Add a control dependency (no output) from the messageAggregatorVertex to the destination.
+    // From src to mbv
+    final List<IRVertex> mbvList = new ArrayList<>();
+    for (final IREdge edge : edgesToGetStatisticsOf) {
+      final IRVertex mbvToAdd = wrapSamplingVertexIfNeeded(
+        new MessageBarrierVertex<>(messageBarrierVertex.getMessageFunction()), edge.getSrc());
+      builder.addVertex(mbvToAdd);
+      mbvList.add(mbvToAdd);
+
+      final IREdge clone = Util.cloneEdge(CommunicationPatternProperty.Value.OneToOne, edge, edge.getSrc(), mbvToAdd);
+      builder.connectVertices(clone);
+    }
+
+    // Add mav (no need to wrap with a sampling vertex)
     builder.addVertex(messageAggregatorVertex);
-    final IREdge noDataEdge = new IREdge(CommunicationPatternProperty.Value.BroadCast, messageAggregatorVertex, dst);
-    builder.connectVertices(noDataEdge);
 
-    // Add the edges and the messageBarrierVertex.
+    // From mbv to mav
+    for (final IRVertex mbv : mbvList) {
+      final IREdge edgeToMav = edgeBetweenMessageVertices(
+        mbv, messageAggregatorVertex, mbvOutputEncoder, mbvOutputDecoder);
+      builder.connectVertices(edgeToMav);
+    }
+
+    // From mav to dst
+    // Add a control dependency (no output) from the messageAggregatorVertex to the destination.
+    builder.connectVertices(
+      Util.createControlEdge(messageAggregatorVertex, edgesToGetStatisticsOf.iterator().next().getDst()));
+
+    ////////////////////////////////// STEP 2: Annotate the MessageId on optimization target edges
+
     modifiedDAG.topologicalDo(v -> {
-      for (final IREdge edge : modifiedDAG.getIncomingEdgesOf(v)) {
-        if (edgesToGetStatisticsOf.contains(edge)) {
-          // MATCH!
-          final MessageBarrierVertex mbv = new MessageBarrierVertex<>(messageBarrierVertex.getMessageFunction());
-          builder.addVertex(mbv);
-
-          // Clone the edgeToGetStatisticsOf
-          final IREdge clone = new IREdge(CommunicationPatternProperty.Value.OneToOne, edge.getSrc(), mbv);
-          clone.setProperty(EncoderProperty.of(edge.getPropertyValue(EncoderProperty.class).get()));
-          clone.setProperty(DecoderProperty.of(edge.getPropertyValue(DecoderProperty.class).get()));
-          edge.getPropertyValue(AdditionalOutputTagProperty.class).ifPresent(tag -> {
-            clone.setProperty(AdditionalOutputTagProperty.of(tag));
-          });
-          builder.connectVertices(clone);
-
-          // messageBarrierVertex to the messageAggregatorVertex
-          final IREdge edgeToABV = edgeBetweenMessageVertices(mbv,
-            messageAggregatorVertex, mbvOutputEncoder, mbvOutputDecoder, currentMetricCollectionId);
-          builder.connectVertices(edgeToABV);
-
-          // The original edge
-          // We then insert the vertex with MessageBarrierTransform and vertex with MessageAggregatorTransform
-          // between the vertex and incoming vertices.
-          final IREdge edgeToOriginalDst =
-            new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(), edge.getSrc(), v);
-          edge.copyExecutionPropertiesTo(edgeToOriginalDst);
-          edgeToOriginalDst.setPropertyPermanently(MessageIdProperty.of(currentMetricCollectionId));
-          builder.connectVertices(edgeToOriginalDst);
-        } else {
-          // NO MATCH, so simply connect vertices as before.
-          builder.connectVertices(edge);
+      modifiedDAG.getIncomingEdgesOf(v).forEach(inEdge -> {
+        if (edgesToOptimize.contains(inEdge)) {
+          inEdge.setPropertyPermanently(MessageIdEdgeProperty.of(
+            messageAggregatorVertex.getPropertyValue(MessageIdVertexProperty.class).get()));
         }
-      }
+      });
     });
 
     modifiedDAG = builder.build(); // update the DAG.
   }
 
   /**
+   * Inserts a set of samplingVertices that process sampled data.
+   *
+   * This method automatically inserts the following three types of edges.
+   * (1) Edges between samplingVertices to reflect the original relationship
+   * (2) Edges from the original IRDAG to samplingVertices that clone the inEdges of the original vertices
+   * (3) Edges from the samplingVertices to the original IRDAG to respect executeAfterSamplingVertices
+   *
+   * Suppose the caller supplies the following arguments to perform a "sampled run" of vertices {V1, V2},
+   * prior to executing them.
+   * - samplingVertices: {V1', V2'}
+   * - executeAfterSamplingVertices: {V1}
+   *
+   * Before: V1 - oneToOneEdge - V2 - shuffleEdge - V3
+   * After: V1' - oneToOneEdge - V2' - controlEdge - V1 - oneToOneEdge - V2 - shuffleEdge - V3
+   *
+   * This preserves semantics as the original IRDAG remains unchanged and unaffected.
+   *
+   * (Future calls to insert() can add new vertices that connect to sampling vertices. Such new vertices will also be
+   * wrapped with sampling vertices, as new vertices that consume outputs from sampling vertices will process
+   * a subset of data anyways, and no such new vertex will reach the original DAG except via control edges)
+   *
+   * TODO #343: Extend SamplingVertex control edges
+   *
+   * @param samplingVertices to insert.
+   * @param executeAfterSamplingVertices that must be executed after samplingVertices.
+   */
+  public void insert(final Set<SamplingVertex> samplingVertices,
+                     final Set<IRVertex> executeAfterSamplingVertices) {
+    samplingVertices.forEach(this::assertNonExistence);
+
+    // Create a completely new DAG with the vertex inserted.
+    final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
+
+    // All of the existing vertices and edges remain intact
+    modifiedDAG.topologicalDo(v -> {
+      builder.addVertex(v);
+      modifiedDAG.getIncomingEdgesOf(v).forEach(builder::connectVertices);
+    });
+
+    // Add the sampling vertices
+    samplingVertices.forEach(builder::addVertex);
+
+    // Get the original vertices
+    final Map<IRVertex, IRVertex> originalToSampling = samplingVertices.stream()
+      .collect(Collectors.toMap(sv -> modifiedDAG.getVertexById(sv.getOriginalVertexId()), Function.identity()));
+    final Set<IREdge> inEdgesOfOriginals = originalToSampling.keySet()
+      .stream()
+      .flatMap(ov -> modifiedDAG.getIncomingEdgesOf(ov).stream())
+      .collect(Collectors.toSet());
+
+    // [EDGE TYPE 1] Between sampling vertices
+    final Set<IREdge> betweenOriginals = inEdgesOfOriginals
+      .stream()
+      .filter(ovInEdge -> originalToSampling.containsKey(ovInEdge.getSrc()))
+      .collect(Collectors.toSet());
+    betweenOriginals.stream().map(boEdge -> Util.cloneEdge(
+      boEdge,
+      originalToSampling.get(boEdge.getSrc()),
+      originalToSampling.get(boEdge.getDst()))).forEach(builder::connectVertices);
+
+    // [EDGE TYPE 2] From original IRDAG to sampling vertices
+    final Set<IREdge> notBetweenOriginals = inEdgesOfOriginals
+      .stream()
+      .filter(ovInEdge -> !originalToSampling.containsKey(ovInEdge.getSrc()))
+      .collect(Collectors.toSet());
+    notBetweenOriginals.stream().map(nboEdge -> {
+      final IREdge cloneEdge = Util.cloneEdge(
+        nboEdge,
+        nboEdge.getSrc(), // sampling vertices consume a subset of original data partitions here
+        originalToSampling.get(nboEdge.getDst()));
+      nboEdge.copyExecutionPropertiesTo(cloneEdge); // exec properties must be exactly the same
+      return cloneEdge;
+    }).forEach(builder::connectVertices);
+
+    // [EDGE TYPE 3] From sampling vertices to vertices that should be executed after
+    final Set<IRVertex> sinks = getSinksWithinVertexSet(modifiedDAG, originalToSampling.keySet())
+      .stream()
+      .map(originalToSampling::get)
+      .collect(Collectors.toSet());
+    for (final IRVertex executeAfter : executeAfterSamplingVertices) {
+      for (final IRVertex sink : sinks) {
+        // Control edge that enforces execution ordering
+        builder.connectVertices(Util.createControlEdge(sink, executeAfter));
+      }
+    }
+
+    modifiedDAG = builder.build(); // update the DAG.
+  }
+
+  /**
    * Reshape unsafely, without guarantees on preserving application semantics.
    * TODO #330: Refactor Unsafe Reshaping Passes
    * @param unsafeReshapingFunction takes as input the underlying DAG, and outputs a reshaped DAG.
@@ -237,25 +364,52 @@ public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
 
   ////////////////////////////////////////////////// Private helper methods.
 
+  private Set<IRVertex> getSinksWithinVertexSet(final DAG<IRVertex, IREdge> dag,
+                                                final Set<IRVertex> vertexSet) {
+    final Set<IRVertex> parentsOfAnotherVertex = vertexSet.stream()
+      .flatMap(v -> dag.getOutgoingEdgesOf(v).stream())
+      .filter(e -> vertexSet.contains(e.getDst())) // keep only edges whose destination is also in vertexSet
+      .map(IREdge::getSrc) // makes the result a subset of the input vertexSet
+      .collect(Collectors.toSet());
+    return Sets.difference(vertexSet, parentsOfAnotherVertex); // sinks: members that feed no other member
+  }
+
+  private IRVertex wrapSamplingVertexIfNeeded(final IRVertex newVertex, final IRVertex existingVertexToConnectWith) {
+    // When connecting to a SamplingVertex, wrap the new vertex in one too, reusing the neighbor's sample rate.
+    return existingVertexToConnectWith instanceof SamplingVertex
+      ? new SamplingVertex(newVertex, ((SamplingVertex) existingVertexToConnectWith).getDesiredSampleRate())
+      : newVertex; // otherwise the new vertex is returned unchanged
+  }
+
+  private void assertNonExistence(final IRVertex v) { // guard: v must not be among getVertices() yet
+    if (getVertices().contains(v)) {
+      throw new IllegalArgumentException(v.getId()); // the message carries only the offending vertex id
+    }
+  }
+
   /**
    * @param mbv src.
    * @param mav dst.
    * @param encoder src-dst encoder.
    * @param decoder src-dst decoder.
-   * @param currentMetricCollectionId of the edge.
    * @return the edge.
    */
-  private IREdge edgeBetweenMessageVertices(final MessageBarrierVertex mbv,
-                                            final MessageAggregatorVertex mav,
+  private IREdge edgeBetweenMessageVertices(final IRVertex mbv,
+                                            final IRVertex mav,
                                             final EncoderProperty encoder,
-                                            final DecoderProperty decoder,
-                                            final int currentMetricCollectionId) {
+                                            final DecoderProperty decoder) {
     final IREdge newEdge = new IREdge(CommunicationPatternProperty.Value.Shuffle, mbv, mav);
     newEdge.setProperty(DataStoreProperty.of(DataStoreProperty.Value.LocalFileStore));
     newEdge.setProperty(DataPersistenceProperty.of(DataPersistenceProperty.Value.Keep));
     newEdge.setProperty(DataFlowProperty.of(DataFlowProperty.Value.Push));
-    newEdge.setPropertyPermanently(MessageIdProperty.of(currentMetricCollectionId));
-    newEdge.setProperty(KeyExtractorProperty.of(new PairKeyExtractor()));
+    final KeyExtractor pairKeyExtractor = (element) -> {
+      if (element instanceof Pair) {
+        return ((Pair) element).left();
+      } else {
+        throw new IllegalStateException(element.toString());
+      }
+    };
+    newEdge.setProperty(KeyExtractorProperty.of(pairKeyExtractor));
     newEdge.setPropertyPermanently(encoder);
     newEdge.setPropertyPermanently(decoder);
     return newEdge;
@@ -322,6 +476,11 @@ public final class IRDAG implements DAGInterface<IRVertex, IREdge> {
   }
 
   @Override
+  public List<IREdge> getEdges() {
+    return modifiedDAG.getEdges();
+  }
+
+  @Override
   public List<IRVertex> getRootVertices() {
     return modifiedDAG.getRootVertices();
   }
diff --git a/common/src/main/java/org/apache/nemo/common/ir/edge/IREdge.java b/common/src/main/java/org/apache/nemo/common/ir/edge/IREdge.java
index 1aa117b..68e5381 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/edge/IREdge.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/edge/IREdge.java
@@ -30,7 +30,7 @@ import org.apache.nemo.common.ir.vertex.IRVertex;
 import org.apache.commons.lang3.builder.HashCodeBuilder;
 
 import java.io.Serializable;
-import java.util.Optional;
+import java.util.*;
 
 /**
  * Physical execution plan of intermediate data movement.
@@ -139,4 +139,17 @@ public final class IREdge extends Edge<IRVertex> {
     node.set("executionProperties", executionProperties.asJsonNode());
     return node;
   }
+
+  /////////// For saving original EPs (e.g., save original encoders/decoders of StreamVertex edges)
+
+  private final Map<Class, EdgeExecutionProperty> snapshot = new HashMap<>();
+
+  public void setPropertySnapshot() { // re-captures the current execution properties into the snapshot map
+    snapshot.clear(); // drop whatever an earlier call captured
+    executionProperties.forEachProperties(p -> snapshot.put(p.getClass(), p)); // keyed by property class
+  }
+
+  public Map<Class, EdgeExecutionProperty> getPropertySnapshot() { // properties captured by setPropertySnapshot()
+    return snapshot; // NOTE(review): hands out the internal mutable map, not a copy
+  }
 }
diff --git a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdProperty.java b/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdEdgeProperty.java
similarity index 78%
copy from common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdProperty.java
copy to common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdEdgeProperty.java
index 3bd74f8..b45c85e 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdProperty.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdEdgeProperty.java
@@ -21,14 +21,14 @@ package org.apache.nemo.common.ir.edge.executionproperty;
 import org.apache.nemo.common.ir.executionproperty.EdgeExecutionProperty;
 
 /**
- * Edges with the same MessageId are subject to the same run-time optimization.
+ * Vertices and edges with the same MessageId are subject to the same run-time optimization.
  */
-public final class MessageIdProperty extends EdgeExecutionProperty<Integer> {
+public final class MessageIdEdgeProperty extends EdgeExecutionProperty<Integer> {
   /**
    * Constructor.
    * @param value value of the execution property.
    */
-  private MessageIdProperty(final Integer value) {
+  private MessageIdEdgeProperty(final Integer value) {
     super(value);
   }
 
@@ -37,7 +37,7 @@ public final class MessageIdProperty extends EdgeExecutionProperty<Integer> {
    * @param value value of the new execution property.
    * @return the newly created execution property.
    */
-  public static MessageIdProperty of(final Integer value) {
-    return new MessageIdProperty(value);
+  public static MessageIdEdgeProperty of(final Integer value) {
+    return new MessageIdEdgeProperty(value);
   }
 }
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/LoopVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/LoopVertex.java
index f279783..b1ce3ca 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/LoopVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/LoopVertex.java
@@ -21,13 +21,13 @@ package org.apache.nemo.common.ir.vertex;
 import com.fasterxml.jackson.databind.node.ArrayNode;
 import com.fasterxml.jackson.databind.node.JsonNodeFactory;
 import com.fasterxml.jackson.databind.node.ObjectNode;
+import org.apache.nemo.common.Util;
 import org.apache.nemo.common.dag.DAG;
 import org.apache.nemo.common.dag.DAGBuilder;
 import org.apache.nemo.common.ir.edge.IREdge;
 import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
 import org.apache.nemo.common.ir.edge.executionproperty.DuplicateEdgeGroupProperty;
 import org.apache.nemo.common.ir.edge.executionproperty.DuplicateEdgeGroupPropertyValue;
-import org.apache.nemo.common.util.Util;
 
 import java.io.Serializable;
 import java.util.*;
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/SourceVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/SourceVertex.java
index d8f6eee..7ad6a82 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/SourceVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/SourceVertex.java
@@ -28,7 +28,6 @@ import java.util.List;
  * @param <O> output type.
  */
 public abstract class SourceVertex<O> extends IRVertex {
-
   /**
    * Constructor for SourceVertex.
    */
@@ -49,6 +48,7 @@ public abstract class SourceVertex<O> extends IRVertex {
   public SourceVertex(final SourceVertex that) {
     super(that);
   }
+
   /**
    * Gets parallel readables.
    *
diff --git a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdProperty.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/MessageIdVertexProperty.java
similarity index 69%
rename from common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdProperty.java
rename to common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/MessageIdVertexProperty.java
index 3bd74f8..9ba842f 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/MessageIdProperty.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/MessageIdVertexProperty.java
@@ -16,19 +16,19 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.nemo.common.ir.edge.executionproperty;
+package org.apache.nemo.common.ir.vertex.executionproperty;
 
-import org.apache.nemo.common.ir.executionproperty.EdgeExecutionProperty;
+import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty;
 
 /**
- * Edges with the same MessageId are subject to the same run-time optimization.
+ * Vertices and edges with the same MessageId belong to the same run-time optimization.
  */
-public final class MessageIdProperty extends EdgeExecutionProperty<Integer> {
+public final class MessageIdVertexProperty extends VertexExecutionProperty<Integer> {
   /**
    * Constructor.
    * @param value value of the execution property.
    */
-  private MessageIdProperty(final Integer value) {
+  private MessageIdVertexProperty(final Integer value) {
     super(value);
   }
 
@@ -37,7 +37,7 @@ public final class MessageIdProperty extends EdgeExecutionProperty<Integer> {
    * @param value value of the new execution property.
    * @return the newly created execution property.
    */
-  public static MessageIdProperty of(final Integer value) {
-    return new MessageIdProperty(value);
+  public static MessageIdVertexProperty of(final Integer value) {
+    return new MessageIdVertexProperty(value);
   }
 }
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageAggregatorVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageAggregatorVertex.java
index a5a9280..4db54d7 100644
--- a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageAggregatorVertex.java
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/MessageAggregatorVertex.java
@@ -20,8 +20,12 @@ package org.apache.nemo.common.ir.vertex.utility;
 
 import org.apache.nemo.common.Pair;
 import org.apache.nemo.common.ir.vertex.OperatorVertex;
+import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty;
 import org.apache.nemo.common.ir.vertex.transform.MessageAggregatorTransform;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.BiFunction;
 
 /**
@@ -30,12 +34,17 @@ import java.util.function.BiFunction;
  * @param <V> of the input pair.
  * @param <O> of the output aggregated message.
  */
-public class MessageAggregatorVertex<K, V, O> extends OperatorVertex {
+public final class MessageAggregatorVertex<K, V, O> extends OperatorVertex {
+  private static final Logger LOG = LoggerFactory.getLogger(MessageAggregatorVertex.class.getName());
+
+  private static final AtomicInteger MESSAGE_ID_GENERATOR = new AtomicInteger(0);
+
   /**
    * @param initialState to use.
    * @param userFunction for aggregating the messages.
    */
   public MessageAggregatorVertex(final O initialState, final BiFunction<Pair<K, V>, O, O> userFunction) {
     super(new MessageAggregatorTransform<>(initialState, userFunction));
+    this.setPropertyPermanently(MessageIdVertexProperty.of(MESSAGE_ID_GENERATOR.incrementAndGet()));
   }
 }
diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/SamplingVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/SamplingVertex.java
new file mode 100644
index 0000000..fe54952
--- /dev/null
+++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/SamplingVertex.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.nemo.common.ir.vertex.utility;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import org.apache.nemo.common.Util;
+import org.apache.nemo.common.ir.edge.IREdge;
+import org.apache.nemo.common.ir.vertex.IRVertex;
+
+/**
+ * Executes the original IRVertex using a subset of input data partitions.
+ */
+public final class SamplingVertex extends IRVertex {
+  private final IRVertex originalVertex;
+  private final IRVertex cloneOfOriginalVertex;
+  private final float desiredSampleRate;
+
+  /**
+   * @param originalVertex to clone; must not itself be a SamplingVertex.
+   * @param desiredSampleRate fraction of tasks to execute, in (0, 1]; other values, NaN included, are rejected.
+   *                          The actual sample rate may vary depending on neighboring sampling vertices.
+   */
+  public SamplingVertex(final IRVertex originalVertex, final float desiredSampleRate) {
+    super();
+    if (originalVertex instanceof SamplingVertex) {
+      throw new IllegalArgumentException("Cannot sample again: " + originalVertex.toString());
+    }
+    if (!(desiredSampleRate > 0 && desiredSampleRate <= 1)) { // negated form so that NaN is rejected as well
+      throw new IllegalArgumentException(String.valueOf(desiredSampleRate));
+    }
+    this.originalVertex = originalVertex;
+    this.cloneOfOriginalVertex = originalVertex.getClone();
+    this.desiredSampleRate = desiredSampleRate;
+
+    // Copy execution properties to both the executable clone and this wrapper.
+    originalVertex.copyExecutionPropertiesTo(cloneOfOriginalVertex);
+    originalVertex.copyExecutionPropertiesTo(this);
+  }
+
+  /**
+   * @return the id of the original vertex for reference.
+   */
+  public String getOriginalVertexId() {
+    return originalVertex.getId();
+  }
+
+  /**
+   * @return the clone of the original vertex.
+   * This clone is intended to be used during the actual execution, as the sampling vertex itself is not executable
+   * and the original vertex should not be executed again.
+   */
+  public IRVertex getCloneOfOriginalVertex() {
+    return cloneOfOriginalVertex;
+  }
+
+  /**
+   * @return the desired sample rate.
+   */
+  public float getDesiredSampleRate() {
+    return desiredSampleRate;
+  }
+
+  /**
+   * Obtains a clone of an original edge that is attached to this sampling vertex.
+   *
+   * Original edge: src - to - dst
+   * When src == originalVertex, return thisSamplingVertex - to - dst
+   * When dst == originalVertex, return src - to - thisSamplingVertex
+   *
+   * @param originalEdge to clone.
+   * @return a clone of the edge.
+   */
+  public IREdge getCloneOfOriginalEdge(final IREdge originalEdge) {
+    if (originalEdge.getSrc().equals(originalVertex)) {
+      return Util.cloneEdge(originalEdge, this, originalEdge.getDst());
+    } else if (originalEdge.getDst().equals(originalVertex)) {
+      return Util.cloneEdge(originalEdge, originalEdge.getSrc(), this);
+    } else {
+      throw new IllegalArgumentException(originalEdge.getId()); // the edge does not touch the original vertex
+    }
+  }
+
+  @Override
+  public String toString() {
+    final StringBuilder sb = new StringBuilder();
+    sb.append("SamplingVertex(desiredSampleRate:");
+    sb.append(String.valueOf(desiredSampleRate));
+    sb.append(")[");
+    sb.append(originalVertex);
+    sb.append("]");
+    return sb.toString();
+  }
+
+  @Override
+  public IRVertex getClone() {
+    return new SamplingVertex(originalVertex, desiredSampleRate); // the constructor clones the original again
+  }
+
+  @Override
+  public JsonNode getPropertiesAsJsonNode() {
+    return getCloneOfOriginalVertex().getPropertiesAsJsonNode(); // reports the wrapped clone's properties
+  }
+}
diff --git a/common/src/test/java/org/apache/nemo/common/util/UtilTest.java b/common/src/test/java/org/apache/nemo/common/util/UtilTest.java
index 4e6869f..e46db33 100644
--- a/common/src/test/java/org/apache/nemo/common/util/UtilTest.java
+++ b/common/src/test/java/org/apache/nemo/common/util/UtilTest.java
@@ -22,6 +22,7 @@ import static org.junit.Assert.assertEquals;
 
 import java.util.function.IntPredicate;
 
+import org.apache.nemo.common.Util;
 import org.junit.Test;
 
 public class UtilTest {
diff --git a/compiler/backend/src/main/java/org/apache/nemo/compiler/backend/nemo/NemoPlanRewriter.java b/compiler/backend/src/main/java/org/apache/nemo/compiler/backend/nemo/NemoPlanRewriter.java
index a457de4..7131be0 100644
--- a/compiler/backend/src/main/java/org/apache/nemo/compiler/backend/nemo/NemoPlanRewriter.java
+++ b/compiler/backend/src/main/java/org/apache/nemo/compiler/backend/nemo/NemoPlanRewriter.java
@@ -20,7 +20,7 @@ package org.apache.nemo.compiler.backend.nemo;
 
 import org.apache.nemo.common.ir.IRDAG;
 import org.apache.nemo.common.ir.edge.IREdge;
-import org.apache.nemo.common.ir.edge.executionproperty.MessageIdProperty;
+import org.apache.nemo.common.ir.edge.executionproperty.MessageIdEdgeProperty;
 import org.apache.nemo.common.ir.executionproperty.ExecutionPropertyMap;
 import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty;
 import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
@@ -87,8 +87,8 @@ public final class NemoPlanRewriter implements PlanRewriter {
       .getVertices()
       .stream()
       .flatMap(v -> currentIRDAG.getIncomingEdgesOf(v).stream())
-      .filter(e -> e.getPropertyValue(MessageIdProperty.class).isPresent()
-        && e.getPropertyValue(MessageIdProperty.class).get() == messageId
+      .filter(e -> e.getPropertyValue(MessageIdEdgeProperty.class).isPresent()
+        && e.getPropertyValue(MessageIdEdgeProperty.class).get() == messageId
         && !(e.getDst() instanceof MessageAggregatorVertex))
       .collect(Collectors.toSet());
     if (examiningEdges.isEmpty()) {
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/LargeShuffleAnnotatingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/LargeShuffleAnnotatingPass.java
index 0958037..b6e4194 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/LargeShuffleAnnotatingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/LargeShuffleAnnotatingPass.java
@@ -18,8 +18,6 @@
  */
 package org.apache.nemo.compiler.optimizer.pass.compiletime.annotating;
 
-import org.apache.nemo.common.coder.BytesDecoderFactory;
-import org.apache.nemo.common.coder.BytesEncoderFactory;
 import org.apache.nemo.common.ir.IRDAG;
 import org.apache.nemo.common.ir.edge.executionproperty.*;
 import org.apache.nemo.common.ir.vertex.executionproperty.ResourceSlotProperty;
@@ -42,9 +40,7 @@ import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
  * Do not encode/compress the byte[]
  * Perform a pull-based and on-disk data transfer with the DedicatedKeyPerElement.
  */
-@Annotates({CompressionProperty.class, DataFlowProperty.class, CompressionProperty.class,
-  DataPersistenceProperty.class, DataStoreProperty.class, DecoderProperty.class, DecompressionProperty.class,
-  EncoderProperty.class, PartitionerProperty.class, ResourceSlotProperty.class})
+@Annotates({DataFlowProperty.class, DataPersistenceProperty.class, DataStoreProperty.class, ResourceSlotProperty.class})
 @Requires(CommunicationPatternProperty.class)
 public final class LargeShuffleAnnotatingPass extends AnnotatingPass {
   /**
@@ -61,11 +57,6 @@ public final class LargeShuffleAnnotatingPass extends AnnotatingPass {
         if (edge.getDst().getClass().equals(StreamVertex.class)) {
           // CASE #1: To a stream vertex
 
-          // Coder and Compression
-          edge.setPropertyPermanently(DecoderProperty.of(BytesDecoderFactory.of()));
-          edge.setPropertyPermanently(CompressionProperty.of(CompressionProperty.Value.LZ4));
-          edge.setPropertyPermanently(DecompressionProperty.of(CompressionProperty.Value.None));
-
           // Data transfers
           edge.setPropertyPermanently(DataFlowProperty.of(DataFlowProperty.Value.Push));
           edge.setPropertyPermanently(DataPersistenceProperty.of(DataPersistenceProperty.Value.Discard));
@@ -76,16 +67,9 @@ public final class LargeShuffleAnnotatingPass extends AnnotatingPass {
         } else if (edge.getSrc().getClass().equals(StreamVertex.class)) {
           // CASE #2: From a stream vertex
 
-          // Coder and Compression
-          edge.setPropertyPermanently(EncoderProperty.of(BytesEncoderFactory.of()));
-          edge.setPropertyPermanently(CompressionProperty.of(CompressionProperty.Value.None));
-          edge.setPropertyPermanently(DecompressionProperty.of(CompressionProperty.Value.LZ4));
-
           // Data transfers
           edge.setPropertyPermanently(DataFlowProperty.of(DataFlowProperty.Value.Pull));
           edge.setPropertyPermanently(DataStoreProperty.of(DataStoreProperty.Value.LocalFileStore));
-          edge.setPropertyPermanently(
-            PartitionerProperty.of(PartitionerProperty.Type.DedicatedKeyPerElement));
         } else {
           // CASE #3: Unrelated to any stream vertices
           edge.setPropertyPermanently(DataFlowProperty.of(DataFlowProperty.Value.Pull));
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingSkewReshapingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingSkewReshapingPass.java
new file mode 100644
index 0000000..8db8085
--- /dev/null
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingSkewReshapingPass.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;
+
+import org.apache.nemo.common.KeyExtractor;
+import org.apache.nemo.common.dag.Edge;
+import org.apache.nemo.common.ir.IRDAG;
+import org.apache.nemo.common.ir.edge.IREdge;
+import org.apache.nemo.common.ir.edge.executionproperty.*;
+import org.apache.nemo.common.ir.vertex.IRVertex;
+import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
+import org.apache.nemo.common.ir.vertex.utility.MessageBarrierVertex;
+import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
+import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.*;
+import java.util.stream.Collectors;
+
+/**
+ * Optimizes the PartitionSet property of shuffle edges to handle data skews using the SamplingVertex.
+ *
+ * This pass effectively partitions the IRDAG by non-oneToOne edges, clones each subDAG partition using SamplingVertex
+ * to process sampled data, and executes each cloned partition prior to executing the corresponding original partition.
+ *
+ * Suppose the IRDAG is partitioned into three sub-DAG partitions with shuffle dependencies as follows:
+ * P1 - P2 - P3
+ *
+ * Then, this pass will produce something like:
+ * P1' - P1
+ *     - P2' - P2 - P3
+ * where Px' consists of SamplingVertex objects that clone the execution of Px.
+ * (P3 is not cloned here because it is a sink partition, and none of the outgoing edges of its vertices needs to be
+ * optimized)
+ *
+ * For each Px' this pass also inserts a MessageBarrierVertex, to use its data statistics for dynamically optimizing
+ * the execution behaviors of Px.
+ */
+@Requires(CommunicationPatternProperty.class)
+public final class SamplingSkewReshapingPass extends ReshapingPass {
+  private static final Logger LOG = LoggerFactory.getLogger(SamplingSkewReshapingPass.class.getName());
+  private static final float SAMPLE_RATE = 0.1f; // desired sample rate given to every SamplingVertex created here
+
+  /**
+   * Default constructor.
+   */
+  public SamplingSkewReshapingPass() {
+    super(SamplingSkewReshapingPass.class);
+  }
+
+  @Override
+  public IRDAG apply(final IRDAG dag) {
+    dag.topologicalDo(v -> {
+      for (final IREdge e : dag.getIncomingEdgesOf(v)) {
+        if (CommunicationPatternProperty.Value.Shuffle.equals(
+          e.getPropertyValue(CommunicationPatternProperty.class).get())) {
+          // Compute the partition and its source vertices
+          final IRVertex shuffleWriter = e.getSrc();
+          final Set<IRVertex> partitionAll = recursivelyBuildPartition(shuffleWriter, dag);
+          final Set<IRVertex> partitionSources = partitionAll.stream().filter(vertexInPartition ->
+            !dag.getIncomingEdgesOf(vertexInPartition).stream()
+              .map(Edge::getSrc)
+              .anyMatch(partitionAll::contains) // a source has no parent inside the partition
+          ).collect(Collectors.toSet());
+
+          // Check if the partition is a sink, in which case we do not create sampling vertices
+          final boolean isSinkPartition = partitionAll.stream()
+            .flatMap(vertexInPartition -> dag.getOutgoingEdgesOf(vertexInPartition).stream())
+            .map(Edge::getDst)
+            .allMatch(partitionAll::contains); // true when no outgoing edge leaves the partition
+          if (isSinkPartition) {
+            break; // NOTE(review): also skips v's remaining in-edges; confirm 'continue' was not intended
+          }
+
+          // Insert sampling vertices.
+          final Set<SamplingVertex> samplingVertices = partitionAll
+            .stream()
+            .map(vertexInPartition -> new SamplingVertex(vertexInPartition, SAMPLE_RATE))
+            .collect(Collectors.toSet());
+          dag.insert(samplingVertices, partitionSources);
+
+          // Insert the message vertex.
+          // We first obtain a clonedShuffleEdge to analyze the data statistics of the shuffle outputs of
+          // the sampling vertex right before shuffle.
+          final SamplingVertex rightBeforeShuffle = samplingVertices.stream()
+            .filter(sv -> sv.getOriginalVertexId().equals(e.getSrc().getId()))
+            .findFirst()
+            .orElseThrow(() -> new IllegalStateException()); // the shuffle writer is always in partitionAll
+          final IREdge clonedShuffleEdge = rightBeforeShuffle.getCloneOfOriginalEdge(e);
+
+          final KeyExtractor keyExtractor = e.getPropertyValue(KeyExtractorProperty.class).get();
+          dag.insert(
+            new MessageBarrierVertex<>(SkewHandlingUtil.getDynOptCollector(keyExtractor)),
+            new MessageAggregatorVertex(new HashMap(), SkewHandlingUtil.getDynOptAggregator()),
+            SkewHandlingUtil.getEncoder(e),
+            SkewHandlingUtil.getDecoder(e),
+            new HashSet<>(Arrays.asList(clonedShuffleEdge)), // this works although the clone is not in the dag
+            new HashSet<>(Arrays.asList(e))); // we want to optimize the original edge, not the clone
+        }
+      }
+    });
+
+    return dag;
+  }
+
+  private Set<IRVertex> recursivelyBuildPartition(final IRVertex curVertex, final IRDAG dag) {
+    final Set<IRVertex> unionSet = new HashSet<>();
+    unionSet.add(curVertex); // the partition always contains the vertex it was started from
+    for (final IREdge inEdge : dag.getIncomingEdgesOf(curVertex)) {
+      if (CommunicationPatternProperty.Value.OneToOne
+        .equals(inEdge.getPropertyValue(CommunicationPatternProperty.class).get())
+        && DataStoreProperty.Value.MemoryStore
+        .equals(inEdge.getPropertyValue(DataStoreProperty.class).get())
+        && dag.getIncomingEdgesOf(curVertex).size() == 1) { // only a single-parent vertex pulls in its parent
+        unionSet.addAll(recursivelyBuildPartition(inEdge.getSrc(), dag));
+      }
+    }
+    return unionSet;
+  }
+}
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewHandlingUtil.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewHandlingUtil.java
new file mode 100644
index 0000000..6231ba4
--- /dev/null
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewHandlingUtil.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;
+
+import org.apache.nemo.common.KeyExtractor;
+import org.apache.nemo.common.Pair;
+import org.apache.nemo.common.coder.*;
+import org.apache.nemo.common.ir.edge.IREdge;
+import org.apache.nemo.common.ir.edge.executionproperty.DecoderProperty;
+import org.apache.nemo.common.ir.edge.executionproperty.EncoderProperty;
+import org.apache.nemo.common.ir.edge.executionproperty.KeyDecoderProperty;
+import org.apache.nemo.common.ir.edge.executionproperty.KeyEncoderProperty;
+
+import java.io.Serializable;
+import java.util.Map;
+import java.util.function.BiFunction;
+
+/**
+ * A utility class for skew handling passes.
+ */
+final class SkewHandlingUtil {
+  private SkewHandlingUtil() { // static helpers only; not meant to be instantiated
+  }
+
+  static BiFunction<Object, Map<Object, Long>, Map<Object, Long>> getDynOptCollector(final KeyExtractor keyExtractor) {
+    return (BiFunction<Object, Map<Object, Long>, Map<Object, Long>> & Serializable)
+      (element, dynOptData) -> {
+        Object key = keyExtractor.extractKey(element); // elements are counted per extracted key
+        if (dynOptData.containsKey(key)) {
+          dynOptData.compute(key, (existingKey, existingCount) -> (long) existingCount + 1L);
+        } else {
+          dynOptData.put(key, 1L); // first element seen for this key
+        }
+        return dynOptData; // the map that was passed in, updated in place
+      };
+  }
+
+  static BiFunction<Pair<Object, Long>, Map<Object, Long>, Map<Object, Long>> getDynOptAggregator() {
+    return (BiFunction<Pair<Object, Long>, Map<Object, Long>, Map<Object, Long>> & Serializable)
+      (element, aggregatedDynOptData) -> {
+        final Object key = element.left();
+        final Long count = element.right(); // count to add for this key
+        if (aggregatedDynOptData.containsKey(key)) {
+          aggregatedDynOptData.compute(key, (existingKey, accumulatedCount) -> accumulatedCount + count);
+        } else {
+          aggregatedDynOptData.put(key, count);
+        }
+        return aggregatedDynOptData; // the map that was passed in, updated in place
+      };
+  }
+
+  static EncoderProperty getEncoder(final IREdge irEdge) {
+    return EncoderProperty.of(PairEncoderFactory // .get() assumes the edge carries a KeyEncoderProperty
+      .of(irEdge.getPropertyValue(KeyEncoderProperty.class).get(), LongEncoderFactory.of()));
+  }
+
+  static DecoderProperty getDecoder(final IREdge irEdge) {
+    return DecoderProperty.of(PairDecoderFactory // .get() assumes the edge carries a KeyDecoderProperty
+      .of(irEdge.getPropertyValue(KeyDecoderProperty.class).get(), LongDecoderFactory.of()));
+  }
+}
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
index 62c3c9a..046dd75 100644
--- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
@@ -19,24 +19,17 @@
 package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;
 
 import org.apache.nemo.common.KeyExtractor;
-import org.apache.nemo.common.Pair;
-import org.apache.nemo.common.coder.*;
 import org.apache.nemo.common.ir.IRDAG;
 import org.apache.nemo.common.ir.edge.IREdge;
 import org.apache.nemo.common.ir.edge.executionproperty.*;
 import org.apache.nemo.common.ir.vertex.utility.MessageAggregatorVertex;
 import org.apache.nemo.common.ir.vertex.utility.MessageBarrierVertex;
 import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
-import org.apache.nemo.common.ir.edge.executionproperty.DecoderProperty;
-import org.apache.nemo.common.ir.edge.executionproperty.EncoderProperty;
 import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
-import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.Annotates;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.Serializable;
 import java.util.*;
-import java.util.function.BiFunction;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -45,7 +38,6 @@ import java.util.stream.Collectors;
  * We insert a {@link MessageBarrierVertex} for each shuffle edge,
  * and aggregate messages for multiple same-destination shuffle edges.
  * */
-@Annotates(PartitionerProperty.class)
 @Requires(CommunicationPatternProperty.class)
 public final class SkewReshapingPass extends ReshapingPass {
   private static final Logger LOG = LoggerFactory.getLogger(SkewReshapingPass.class.getName());
@@ -78,43 +70,12 @@ public final class SkewReshapingPass extends ReshapingPass {
         // Get the key extractor
         final KeyExtractor keyExtractor = representativeEdge.getPropertyValue(KeyExtractorProperty.class).get();
 
-        // For collecting the data
-        final BiFunction<Object, Map<Object, Long>, Map<Object, Long>> dynOptDataCollector =
-          (BiFunction<Object, Map<Object, Long>, Map<Object, Long>> & Serializable)
-            (element, dynOptData) -> {
-              Object key = keyExtractor.extractKey(element);
-              if (dynOptData.containsKey(key)) {
-                dynOptData.compute(key, (existingKey, existingCount) -> (long) existingCount + 1L);
-              } else {
-                dynOptData.put(key, 1L);
-              }
-              return dynOptData;
-            };
-
-        // For aggregating the collected data
-        final BiFunction<Pair<Object, Long>, Map<Object, Long>, Map<Object, Long>> dynOptDataAggregator =
-          (BiFunction<Pair<Object, Long>, Map<Object, Long>, Map<Object, Long>> & Serializable)
-            (element, aggregatedDynOptData) -> {
-              final Object key = element.left();
-              final Long count = element.right();
-              if (aggregatedDynOptData.containsKey(key)) {
-                aggregatedDynOptData.compute(key, (existingKey, accumulatedCount) -> accumulatedCount + count);
-              } else {
-                aggregatedDynOptData.put(key, count);
-              }
-              return aggregatedDynOptData;
-            };
-
-        // Coders to use
-        final EncoderProperty encoderProperty = EncoderProperty.of(PairEncoderFactory.
-          of(representativeEdge.getPropertyValue(KeyEncoderProperty.class).get(), LongEncoderFactory.of()));
-        final DecoderProperty decoderProperty = DecoderProperty.of(PairDecoderFactory
-          .of(representativeEdge.getPropertyValue(KeyDecoderProperty.class).get(), LongDecoderFactory.of()));
-
         // Insert the vertices
-        final MessageBarrierVertex mbv = new MessageBarrierVertex<>(dynOptDataCollector);
-        final MessageAggregatorVertex mav = new MessageAggregatorVertex(new HashMap(), dynOptDataAggregator);
-        dag.insert(mbv, mav, encoderProperty, decoderProperty, shuffleEdgeGroup);
+        final MessageBarrierVertex mbv = new MessageBarrierVertex<>(SkewHandlingUtil.getDynOptCollector(keyExtractor));
+        final MessageAggregatorVertex mav =
+          new MessageAggregatorVertex(new HashMap(), SkewHandlingUtil.getDynOptAggregator());
+        dag.insert(mbv, mav, SkewHandlingUtil.getEncoder(representativeEdge),
+          SkewHandlingUtil.getDecoder(representativeEdge), shuffleEdgeGroup, shuffleEdgeGroup);
       }
     });
     return dag;
diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/SamplingLargeShuffleSkewPolicy.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/SamplingLargeShuffleSkewPolicy.java
new file mode 100644
index 0000000..84d9470
--- /dev/null
+++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/SamplingLargeShuffleSkewPolicy.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.nemo.compiler.optimizer.policy;
+
+import org.apache.nemo.common.ir.IRDAG;
+import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.DefaultParallelismPass;
+import org.apache.nemo.compiler.optimizer.pass.compiletime.composite.*;
+import org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping.SamplingSkewReshapingPass;
+import org.apache.nemo.compiler.optimizer.pass.runtime.Message;
+import org.apache.nemo.compiler.optimizer.pass.runtime.SkewRunTimePass;
+
+/**
+ * A policy combining large shuffle (which batches disk seeks during data shuffle) and sampling-based skew handling.
+ */
+public final class SamplingLargeShuffleSkewPolicy implements Policy {
+  public static final PolicyBuilder BUILDER =
+    new PolicyBuilder()
+      .registerCompileTimePass(new DefaultParallelismPass())
+      .registerCompileTimePass(new LargeShuffleCompositePass())
+      .registerRunTimePass(new SkewRunTimePass(), new SamplingSkewReshapingPass())
+      .registerCompileTimePass(new LoopOptimizationCompositePass())
+      .registerCompileTimePass(new DefaultCompositePass());
+
+  private final Policy policy;
+
+  /**
+   * Default constructor.
+   */
+  public SamplingLargeShuffleSkewPolicy() {
+    this.policy = BUILDER.build();
+  }
+
+  @Override
+  public IRDAG runCompileTimeOptimization(final IRDAG dag, final String dagDirectory) {
+    return this.policy.runCompileTimeOptimization(dag, dagDirectory);
+  }
+
+  @Override
+  public IRDAG runRunTimeOptimizations(final IRDAG dag, final Message<?> message) {
+    return this.policy.runRunTimeOptimizations(dag, message);
+  }
+}
diff --git a/compiler/test/src/test/java/org/apache/nemo/compiler/backend/nemo/DAGConverterTest.java b/compiler/test/src/test/java/org/apache/nemo/compiler/backend/nemo/DAGConverterTest.java
index 4e0b0bd..95193e1 100644
--- a/compiler/test/src/test/java/org/apache/nemo/compiler/backend/nemo/DAGConverterTest.java
+++ b/compiler/test/src/test/java/org/apache/nemo/compiler/backend/nemo/DAGConverterTest.java
@@ -102,8 +102,8 @@ public final class DAGConverterTest {
     assertEquals(physicalDAG.getOutgoingEdgesOf(physicalStage1).size(), 1);
     assertEquals(physicalDAG.getOutgoingEdgesOf(physicalStage2).size(), 0);
 
-    assertEquals(3, physicalStage1.getParallelism());
-    assertEquals(2, physicalStage2.getParallelism());
+    assertEquals(3, physicalStage1.getTaskIndices().size());
+    assertEquals(2, physicalStage2.getTaskIndices().size());
   }
 
   @Test
diff --git a/examples/beam/src/test/java/org/apache/nemo/examples/beam/PerKeyMedianITCase.java b/examples/beam/src/test/java/org/apache/nemo/examples/beam/PerKeyMedianITCase.java
index 39d8d28..7ee9dd6 100644
--- a/examples/beam/src/test/java/org/apache/nemo/examples/beam/PerKeyMedianITCase.java
+++ b/examples/beam/src/test/java/org/apache/nemo/examples/beam/PerKeyMedianITCase.java
@@ -22,6 +22,7 @@ import org.apache.nemo.client.JobLauncher;
 import org.apache.nemo.common.test.ArgBuilder;
 import org.apache.nemo.common.test.ExampleTestArgs;
 import org.apache.nemo.common.test.ExampleTestUtil;
+import org.apache.nemo.compiler.optimizer.policy.SamplingLargeShuffleSkewPolicy;
 import org.apache.nemo.examples.beam.policy.DataSkewPolicyParallelismFive;
 import org.junit.After;
 import org.junit.Before;
@@ -73,4 +74,16 @@ public final class PerKeyMedianITCase {
         .addOptimizationPolicy(DataSkewPolicyParallelismFive.class.getCanonicalName())
         .build());
   }
+
+  /**
+   * Testing large shuffle and data skew dynamic optimization.
+   * @throws Exception exception on the way.
+   */
+  @Test (timeout = ExampleTestArgs.TIMEOUT)
+  public void testLargeShuffleSamplingSkew() throws Exception {
+    JobLauncher.main(builder
+      .addJobId(PerKeyMedianITCase.class.getSimpleName() + "_LargeShuffleSamplingSkew")
+      .addOptimizationPolicy(SamplingLargeShuffleSkewPolicy.class.getCanonicalName())
+      .build());
+  }
 }
diff --git a/examples/resources/executors/beam_test_one_executor_resources.json b/examples/resources/executors/beam_test_one_executor_resources.json
index 4d6aff4..69f4399 100644
--- a/examples/resources/executors/beam_test_one_executor_resources.json
+++ b/examples/resources/executors/beam_test_one_executor_resources.json
@@ -2,6 +2,6 @@
   {
     "type": "Transient",
     "memory_mb": 512,
-    "capacity": 2
+    "capacity": 15
   }
 ]
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGenerator.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGenerator.java
index 242e747..798edeb 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGenerator.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGenerator.java
@@ -28,6 +28,7 @@ import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty;
 import org.apache.nemo.common.ir.vertex.*;
 import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
 import org.apache.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty;
+import org.apache.nemo.common.ir.vertex.utility.SamplingVertex;
 import org.apache.nemo.conf.JobConf;
 import org.apache.nemo.common.dag.DAG;
 import org.apache.nemo.common.dag.DAGBuilder;
@@ -42,7 +43,10 @@ import org.slf4j.LoggerFactory;
 
 import javax.inject.Inject;
 import java.util.*;
+import java.util.function.BinaryOperator;
 import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 /**
  * A function that converts an IR DAG to physical DAG.
@@ -80,7 +84,11 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
     handleDuplicateEdgeGroupProperty(dagOfStages);
 
     // Split StageGroup by Pull StageEdges
-    splitScheduleGroupByPullStageEdges(dagOfStages);
+    //
+    // TODO #337: IRDAG Unit Tests
+    // Move this test to IRDAG unit tests.
+    //
+    // splitScheduleGroupByPullStageEdges(dagOfStages);
 
     // for debugging purposes.
     dagOfStages.storeJSON(dagDirectory, "plan-logical", "logical execution plan");
@@ -136,6 +144,7 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
     final Map<Integer, Stage> stageIdToStageMap = new HashMap<>();
     final Map<IRVertex, Integer> vertexToStageIdMap = stagePartitioner.apply(irDAG);
     final HashSet<IRVertex> isStagePartitioned = new HashSet<>();
+    final Random random = new Random(hashCode()); // to produce same results for same input IRDAGs
 
     final Map<Integer, Set<IRVertex>> vertexSetForEachStage = new LinkedHashMap<>();
     irDAG.topologicalDo(irVertex -> {
@@ -151,10 +160,9 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
       final String stageIdentifier = RuntimeIdManager.generateStageId(stageId);
       final ExecutionPropertyMap<VertexExecutionProperty> stageProperties = new ExecutionPropertyMap<>(stageIdentifier);
       stagePartitioner.getStageProperties(stageVertices.iterator().next()).forEach(stageProperties::put);
-
       final int stageParallelism = stageProperties.get(ParallelismProperty.class)
         .orElseThrow(() -> new RuntimeException("Parallelism property must be set for Stage"));
-
+      final List<Integer> taskIndices = getTaskIndicesToExecute(stageVertices, stageParallelism, random);
       final DAGBuilder<IRVertex, RuntimeEdge<IRVertex>> stageInternalDAGBuilder = new DAGBuilder<>();
 
       // Prepare vertexIdToReadables
@@ -164,14 +172,16 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
       }
 
       // For each IRVertex,
-      for (final IRVertex irVertex : stageVertices) {
+      for (final IRVertex v : stageVertices) {
+        final IRVertex vertexToPutIntoStage = getActualVertexToPutIntoStage(v);
+
         // Take care of the readables of a source vertex.
-        if (irVertex instanceof SourceVertex && !isStagePartitioned.contains(irVertex)) {
-          final SourceVertex sourceVertex = (SourceVertex) irVertex;
+        if (vertexToPutIntoStage instanceof SourceVertex && !isStagePartitioned.contains(vertexToPutIntoStage)) {
+          final SourceVertex sourceVertex = (SourceVertex) vertexToPutIntoStage;
           try {
             final List<Readable> readables = sourceVertex.getReadables(stageParallelism);
             for (int i = 0; i < stageParallelism; i++) {
-              vertexIdToReadables.get(i).put(irVertex.getId(), readables.get(i));
+              vertexIdToReadables.get(i).put(vertexToPutIntoStage.getId(), readables.get(i));
             }
           } catch (final Exception e) {
             throw new PhysicalPlanGenerationException(e);
@@ -181,7 +191,7 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
         }
 
         // Add vertex to the stage.
-        stageInternalDAGBuilder.addVertex(irVertex);
+        stageInternalDAGBuilder.addVertex(vertexToPutIntoStage);
       }
 
       for (final IRVertex dstVertex : stageVertices) {
@@ -194,8 +204,8 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
             stageInternalDAGBuilder.connectVertices(new RuntimeEdge<>(
               irEdge.getId(),
               irEdge.getExecutionProperties(),
-              irEdge.getSrc(),
-              irEdge.getDst()));
+              getActualVertexToPutIntoStage(irEdge.getSrc()),
+              getActualVertexToPutIntoStage(irEdge.getDst())));
           } else { // edge comes from another stage
             interStageEdges.add(irEdge);
           }
@@ -205,7 +215,12 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
       if (!stageInternalDAGBuilder.isEmpty()) {
         final DAG<IRVertex, RuntimeEdge<IRVertex>> stageInternalDAG
           = stageInternalDAGBuilder.buildWithoutSourceSinkCheck();
-        final Stage stage = new Stage(stageIdentifier, stageInternalDAG, stageProperties, vertexIdToReadables);
+        final Stage stage = new Stage(
+          stageIdentifier,
+          taskIndices,
+          stageInternalDAG,
+          stageProperties,
+          vertexIdToReadables);
         dagOfStagesBuilder.addVertex(stage);
         stageIdToStageMap.put(stageId, stage);
       }
@@ -225,13 +240,52 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
           dstStage == null ? String.format(" destination stage for %s", interStageEdge.getDst()) : ""));
       }
       dagOfStagesBuilder.connectVertices(new StageEdge(interStageEdge.getId(), interStageEdge.getExecutionProperties(),
-        interStageEdge.getSrc(), interStageEdge.getDst(), srcStage, dstStage));
+        getActualVertexToPutIntoStage(interStageEdge.getSrc()), getActualVertexToPutIntoStage(interStageEdge.getDst()),
+        srcStage, dstStage));
     }
 
     return dagOfStagesBuilder.build();
   }
 
   /**
+   * This method is needed, because we do not want to put Sampling vertices into a stage.
+   * The underlying runtime only understands Source and Operator vertices.
+   */
+  private IRVertex getActualVertexToPutIntoStage(final IRVertex irVertex) {
+    return irVertex instanceof SamplingVertex
+      ? ((SamplingVertex) irVertex).getCloneOfOriginalVertex()
+      : irVertex;
+  }
+
+  /**
+   * Randomly select task indices for Sampling vertices.
+   * Select all task indices for non-Sampling vertices.
+   */
+  private List<Integer> getTaskIndicesToExecute(final Set<IRVertex> vertices,
+                                                final int stageParallelism,
+                                                final Random random) {
+    if (vertices.stream().map(v -> v instanceof SamplingVertex).collect(Collectors.toSet()).size() != 1) {
+      throw new IllegalArgumentException("Must be either all sampling vertices, or none: " + vertices.toString());
+    }
+
+    if (vertices.iterator().next() instanceof SamplingVertex) {
+      // Use min of the desired sample rates
+      final float minSampleRate = vertices.stream()
+        .map(v -> ((SamplingVertex) v).getDesiredSampleRate())
+        .reduce(BinaryOperator.minBy(Float::compareTo))
+        .orElseThrow(() -> new IllegalArgumentException(vertices.toString()));
+
+      // Compute and return indices
+      final int numOfTaskIndices = (int) Math.ceil(stageParallelism * minSampleRate);
+      final List<Integer> randomIndices = IntStream.range(0, stageParallelism).boxed().collect(Collectors.toList());
+      Collections.shuffle(randomIndices, random);
+      return new ArrayList<>(randomIndices.subList(0, numOfTaskIndices)); // subList is not serializable.
+    } else {
+      return IntStream.range(0, stageParallelism).boxed().collect(Collectors.toList());
+    }
+  }
+
+  /**
    * Integrity check for Stage.
    * @param stage to check for
    */
@@ -243,7 +297,7 @@ public final class PhysicalPlanGenerator implements Function<IRDAG, DAG<Stage, S
 
     stage.getIRDAG().getVertices().forEach(irVertex -> {
       // Check vertex type.
-      if (!(irVertex instanceof  SourceVertex
+      if (!(irVertex instanceof SourceVertex
         || irVertex instanceof OperatorVertex)) {
         throw new UnsupportedOperationException(irVertex.toString());
       }
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java
index 600b4ad..070ca62 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java
@@ -39,6 +39,7 @@ import java.util.Optional;
  * Stage.
  */
 public final class Stage extends Vertex {
+  private final List<Integer> taskIndices;
   private final DAG<IRVertex, RuntimeEdge<IRVertex>> irDag;
   private final byte[] serializedIRDag;
   private final List<Map<String, Readable>> vertexIdToReadables;
@@ -49,15 +50,18 @@ public final class Stage extends Vertex {
    * Constructor.
    *
    * @param stageId             ID of the stage.
+   * @param taskIndices         indices of the tasks to execute.
    * @param irDag               the DAG of the task in this stage.
    * @param executionProperties set of {@link VertexExecutionProperty} for this stage
    * @param vertexIdToReadables the list of maps between vertex ID and {@link Readable}.
    */
   public Stage(final String stageId,
+               final List<Integer> taskIndices,
                final DAG<IRVertex, RuntimeEdge<IRVertex>> irDag,
                final ExecutionPropertyMap<VertexExecutionProperty> executionProperties,
                final List<Map<String, Readable>> vertexIdToReadables) {
     super(stageId);
+    this.taskIndices = taskIndices;
     this.irDag = irDag;
     this.serializedIRDag = SerializationUtils.serialize(irDag);
     this.executionProperties = executionProperties;
@@ -79,11 +83,20 @@ public final class Stage extends Vertex {
   }
 
   /**
-   * @return the parallelism
+   * @return task indices of this stage to execute.
+   * For non-sampling vertices, returns [0, 1, 2, ..., parallelism-1].
+   * For sampling vertices, returns a randomly chosen subset of size ceil(parallelism * samplingRate).
+   */
+  public List<Integer> getTaskIndices() {
+    return taskIndices;
+  }
+
+  /**
+   * @return the parallelism.
    */
   public int getParallelism() {
     return executionProperties.get(ParallelismProperty.class)
-        .orElseThrow(() -> new RuntimeException("Parallelism property must be set for Stage"));
+      .orElseThrow(() -> new RuntimeException("Parallelism property must be set for Stage"));
   }
 
   /**
@@ -133,6 +146,7 @@ public final class Stage extends Vertex {
     node.put("scheduleGroup", getScheduleGroup());
     node.set("irDag", irDag.asJsonNode());
     node.put("parallelism", getParallelism());
+    node.put("num of task indices", getTaskIndices().size());
     node.set("executionProperties", executionProperties.asJsonNode());
     return node;
   }
diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/StageEdge.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/StageEdge.java
index f73cc75..04c0f74 100644
--- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/StageEdge.java
+++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/StageEdge.java
@@ -158,14 +158,12 @@ public final class StageEdge extends RuntimeEdge<Stage> {
    * @return {@link org.apache.nemo.common.ir.edge.executionproperty.PartitionSetProperty} value.
    */
   public List<KeyRange> getKeyRanges() {
-    final ArrayList<KeyRange> defaultPartitionSet = new ArrayList<>(getDst().getParallelism());
-    for (int taskIdx = 0; taskIdx < getDst().getParallelism(); taskIdx++) {
-      defaultPartitionSet.add(taskIdx, HashRange.of(taskIdx, taskIdx + 1));
+    final ArrayList<KeyRange> defaultPartitionSet = new ArrayList<>();
+    for (int taskIndex = 0; taskIndex <  getDst().getParallelism(); taskIndex++) {
+      defaultPartitionSet.add(taskIndex, HashRange.of(taskIndex, taskIndex + 1));
     }
     final List<KeyRange> keyRanges = getExecutionProperties()
       .get(PartitionSetProperty.class).orElse(defaultPartitionSet);
-    LOG.info("{} -> {} getKeyRanges {}", srcVertex.getId(), dstVertex.getId(), keyRanges);
-
     return keyRanges;
   }
 }
diff --git a/runtime/common/src/test/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGeneratorTest.java b/runtime/common/src/test/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGeneratorTest.java
index 7e6dcbe..40f2c2b 100644
--- a/runtime/common/src/test/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGeneratorTest.java
+++ b/runtime/common/src/test/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGeneratorTest.java
@@ -30,7 +30,6 @@ import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
 import org.apache.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty;
 import org.apache.reef.tang.Injector;
 import org.apache.reef.tang.Tang;
-import org.junit.Test;
 
 import java.util.Iterator;
 
@@ -45,8 +44,10 @@ public final class PhysicalPlanGeneratorTest {
   /**
    * Test splitting ScheduleGroups by Pull StageEdges.
    * @throws Exception exceptions on the way
+   *
+   * TODO #337: IRDAG Unit Tests
+   * Move this test to IRDAG unit tests.
    */
-  @Test
   public void testSplitScheduleGroupByPullStageEdges() throws Exception {
     final Injector injector = Tang.Factory.getTang().newInjector();
     final PhysicalPlanGenerator physicalPlanGenerator = injector.getInstance(PhysicalPlanGenerator.class);
diff --git a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/TestUtil.java b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/TestUtil.java
index 9b752c8..b1832ec 100644
--- a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/TestUtil.java
+++ b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/TestUtil.java
@@ -26,9 +26,9 @@ import java.util.List;
 
 public final class TestUtil {
   public static List<String> generateTaskIds(final Stage stage) {
-    final List<String> result = new ArrayList<>(stage.getParallelism());
+    final List<String> result = new ArrayList<>();
     final int first_attempt = 0;
-    for (int taskIndex = 0; taskIndex < stage.getParallelism(); taskIndex++) {
+    for (final int taskIndex : stage.getTaskIndices()) {
       result.add(RuntimeIdManager.generateTaskId(stage.getId(), taskIndex, first_attempt));
     }
     return result;
diff --git a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/datatransfer/DataTransferTest.java b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/datatransfer/DataTransferTest.java
index 2a3d390..88884eb 100644
--- a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/datatransfer/DataTransferTest.java
+++ b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/datatransfer/DataTransferTest.java
@@ -520,7 +520,12 @@ public final class DataTransferTest {
     final ExecutionPropertyMap<VertexExecutionProperty> stageExecutionProperty = new ExecutionPropertyMap<>(stageId);
     stageExecutionProperty.put(ParallelismProperty.of(PARALLELISM_TEN));
     stageExecutionProperty.put(ScheduleGroupProperty.of(0));
-    return new Stage(stageId, emptyDag, stageExecutionProperty, Collections.emptyList());
+    return new Stage(
+      stageId,
+      IntStream.range(0, PARALLELISM_TEN).boxed().collect(Collectors.toList()),
+      emptyDag,
+      stageExecutionProperty,
+      Collections.emptyList());
   }
 
   /**
diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java
index 88f3686..702e9ea 100644
--- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java
+++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java
@@ -44,7 +44,6 @@ import org.apache.nemo.runtime.common.state.TaskState;
 import org.apache.nemo.runtime.common.metric.JobMetric;
 import org.apache.nemo.runtime.common.metric.StageMetric;
 import org.apache.nemo.runtime.common.metric.TaskMetric;
-import org.apache.nemo.runtime.master.metric.MetricMessageHandler;
 import org.apache.nemo.runtime.master.metric.MetricStore;
 import org.apache.reef.annotations.audience.DriverSide;
 import org.apache.reef.tang.annotations.Parameter;
@@ -77,7 +76,9 @@ public final class PlanStateManager {
    */
   private PlanState planState;
   private final Map<String, StageState> stageIdToState;
-  private final Map<String, List<List<TaskState>>> stageIdToTaskAttemptStates; // sorted by task idx, and then attempt
+
+  // stage id -> task idx -> list of attempt states, sorted by attempt idx
+  private final Map<String, Map<Integer, List<TaskState>>> stageIdToTaskIdxToAttemptStates;
 
   /**
    * Used for speculative cloning. (in the unit of milliseconds - ms)
@@ -101,21 +102,16 @@ public final class PlanStateManager {
    * For metrics.
    */
   private final String dagDirectory;
-  private final MetricMessageHandler metricMessageHandler;
   private MetricStore metricStore;
 
   /**
    * Constructor.
-   *
-   * @param metricMessageHandler the metric handler for the plan.
    */
   @Inject
-  private PlanStateManager(@Parameter(JobConf.DAGDirectory.class) final String dagDirectory,
-                           final MetricMessageHandler metricMessageHandler) {
-    this.metricMessageHandler = metricMessageHandler;
+  private PlanStateManager(@Parameter(JobConf.DAGDirectory.class) final String dagDirectory) {
     this.planState = new PlanState();
     this.stageIdToState = new HashMap<>();
-    this.stageIdToTaskAttemptStates = new HashMap<>();
+    this.stageIdToTaskIdxToAttemptStates = new HashMap<>();
     this.finishLock = new ReentrantLock();
     this.planFinishedCondition = finishLock.newCondition();
     this.dagDirectory = dagDirectory;
@@ -155,11 +151,11 @@ public final class PlanStateManager {
     physicalPlan.getStageDAG().topologicalDo(stage -> {
       if (!stageIdToState.containsKey(stage.getId())) {
         stageIdToState.put(stage.getId(), new StageState());
-        stageIdToTaskAttemptStates.put(stage.getId(), new ArrayList<>(stage.getParallelism()));
+        stageIdToTaskIdxToAttemptStates.put(stage.getId(), new HashMap<>());
 
         // for each task idx of this stage
-        for (int taskIndex = 0; taskIndex < stage.getParallelism(); taskIndex++) {
-          stageIdToTaskAttemptStates.get(stage.getId()).add(new ArrayList<>());
+        for (final int taskIndex : stage.getTaskIndices()) {
+          stageIdToTaskIdxToAttemptStates.get(stage.getId()).put(taskIndex, new ArrayList<>());
           // task states will be initialized lazily in getTaskAttemptsToSchedule()
         }
       }
@@ -184,9 +180,9 @@ public final class PlanStateManager {
     // For each task index....
     final List<String> taskAttemptsToSchedule = new ArrayList<>();
     final Stage stage = physicalPlan.getStageDAG().getVertexById(stageId);
-    for (int taskIndex = 0; taskIndex < stage.getParallelism(); taskIndex++) {
+    for (final int taskIndex : stage.getTaskIndices()) {
       final List<TaskState> attemptStatesForThisTaskIndex =
-        stageIdToTaskAttemptStates.get(stageId).get(taskIndex);
+        stageIdToTaskIdxToAttemptStates.get(stageId).get(taskIndex);
 
       // If one of the attempts is COMPLETE, do not schedule
       if (attemptStatesForThisTaskIndex
@@ -247,9 +243,9 @@ public final class PlanStateManager {
     final long curTime = System.currentTimeMillis();
     final Map<String, Long> result = new HashMap<>();
 
-    final List<List<TaskState>> taskStates = stageIdToTaskAttemptStates.get(stageId);
-    for (int taskIndex = 0; taskIndex < taskStates.size(); taskIndex++) {
-      final List<TaskState> attemptStates = taskStates.get(taskIndex);
+    final Map<Integer, List<TaskState>> taskIdToState = stageIdToTaskIdxToAttemptStates.get(stageId);
+    for (final int taskIndex : taskIdToState.keySet()) {
+      final List<TaskState> attemptStates = taskIdToState.get(taskIndex);
       for (int attempt = 0; attempt < attemptStates.size(); attempt++) {
         if (TaskState.State.EXECUTING.equals(attemptStates.get(attempt).getStateMachine().getCurrentState())) {
           final String taskId = RuntimeIdManager.generateTaskId(stageId, taskIndex, attempt);
@@ -317,8 +313,8 @@ public final class PlanStateManager {
 
     // Log not-yet-completed tasks for us humans to track progress
     final String stageId = RuntimeIdManager.getStageIdFromTaskId(taskId);
-    final List<List<TaskState>> taskStatesOfThisStage = stageIdToTaskAttemptStates.get(stageId);
-    final long numOfCompletedTaskIndicesInThisStage = taskStatesOfThisStage.stream()
+    final Map<Integer, List<TaskState>> taskStatesOfThisStage = stageIdToTaskIdxToAttemptStates.get(stageId);
+    final long numOfCompletedTaskIndicesInThisStage = taskStatesOfThisStage.values().stream()
       .filter(attempts -> {
         final List<TaskState.State> states = attempts
           .stream()
@@ -357,7 +353,7 @@ public final class PlanStateManager {
       case COMPLETE:
       case ON_HOLD:
         if (numOfCompletedTaskIndicesInThisStage
-          == physicalPlan.getStageDAG().getVertexById(stageId).getParallelism()) {
+          == physicalPlan.getStageDAG().getVertexById(stageId).getTaskIndices().size()) {
           onStageStateChanged(stageId, StageState.State.COMPLETE);
         }
         break;
@@ -540,9 +536,9 @@ public final class PlanStateManager {
 
   private Map<String, TaskState.State> getTaskAttemptIdsToItsState(final String stageId) {
     final Map<String, TaskState.State> result = new HashMap<>();
-    final List<List<TaskState>> taskStates = stageIdToTaskAttemptStates.get(stageId);
-    for (int taskIndex = 0; taskIndex < taskStates.size(); taskIndex++) {
-      final List<TaskState> attemptStates = taskStates.get(taskIndex);
+    final Map<Integer, List<TaskState>> taskIdToState = stageIdToTaskIdxToAttemptStates.get(stageId);
+    for (final int taskIndex : taskIdToState.keySet()) {
+      final List<TaskState> attemptStates = taskIdToState.get(taskIndex);
       for (int attempt = 0; attempt < attemptStates.size(); attempt++) {
         result.put(RuntimeIdManager.generateTaskId(stageId, taskIndex, attempt),
           (TaskState.State) attemptStates.get(attempt).getStateMachine().getCurrentState());
@@ -552,7 +548,7 @@ public final class PlanStateManager {
   }
 
   private TaskState getTaskStateHelper(final String taskId) {
-    return stageIdToTaskAttemptStates
+    return stageIdToTaskIdxToAttemptStates
       .get(RuntimeIdManager.getStageIdFromTaskId(taskId))
       .get(RuntimeIdManager.getIndexFromTaskId(taskId))
       .get(RuntimeIdManager.getAttemptFromTaskId(taskId));
@@ -571,7 +567,7 @@ public final class PlanStateManager {
     final int attempt = RuntimeIdManager.getAttemptFromTaskId(taskId);
 
     final List<TaskState> otherAttemptsforTheSameTaskIndex =
-      new ArrayList<>(stageIdToTaskAttemptStates.get(stageId).get(taskIndex));
+      new ArrayList<>(stageIdToTaskIdxToAttemptStates.get(stageId).get(taskIndex));
     otherAttemptsforTheSameTaskIndex.remove(attempt);
 
     return otherAttemptsforTheSameTaskIndex.stream()
diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
index bcf0976..26bbb27 100644
--- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
+++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
@@ -22,9 +22,10 @@ import com.google.common.collect.Sets;
 import org.apache.nemo.common.Pair;
 import org.apache.nemo.common.dag.DAG;
 import org.apache.nemo.common.ir.Readable;
-import org.apache.nemo.common.ir.edge.executionproperty.MessageIdProperty;
+import org.apache.nemo.common.ir.edge.executionproperty.MessageIdEdgeProperty;
 import org.apache.nemo.common.ir.vertex.executionproperty.ClonedSchedulingProperty;
 import org.apache.nemo.common.ir.vertex.executionproperty.IgnoreSchedulingTempDataReceiverProperty;
+import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty;
 import org.apache.nemo.runtime.common.RuntimeIdManager;
 import org.apache.nemo.runtime.common.plan.*;
 import org.apache.nemo.runtime.common.state.BlockState;
@@ -163,7 +164,7 @@ public final class BatchScheduler implements Scheduler {
 
   private int getMessageId(final Set<StageEdge> stageEdges) {
     final Set<Integer> messageIds = stageEdges.stream()
-      .map(edge -> edge.getExecutionProperties().get(MessageIdProperty.class).get())
+      .map(edge -> edge.getExecutionProperties().get(MessageIdEdgeProperty.class).get())
       .collect(Collectors.toSet());
     if (messageIds.size() != 1) {
       throw new IllegalArgumentException(stageEdges.toString());
@@ -288,13 +289,12 @@ public final class BatchScheduler implements Scheduler {
         stage.getPropertyValue(ClonedSchedulingProperty.class).ifPresent(cloneConf -> {
           if (!cloneConf.isUpFrontCloning()) { // Upfront cloning is already handled.
             final double fractionToWaitFor = cloneConf.getFractionToWaitFor();
-            final int parallelism = stage.getParallelism();
             final Object[] completedTaskTimes = planStateManager.getCompletedTaskTimeListMs(stageId).toArray();
 
             // Only after the fraction of the tasks are done...
             // Delayed cloning (aggressive)
             if (completedTaskTimes.length > 0
-              && completedTaskTimes.length >= Math.round(parallelism * fractionToWaitFor)) {
+              && completedTaskTimes.length >= Math.round(stage.getTaskIndices().size() * fractionToWaitFor)) {
               Arrays.sort(completedTaskTimes);
               final long medianTime = (long) completedTaskTimes[completedTaskTimes.length / 2];
               final double medianTimeMultiplier = cloneConf.getMedianTimeMultiplier();
@@ -465,7 +465,7 @@ public final class BatchScheduler implements Scheduler {
 
   /**
    * Get the target edges of dynamic optimization.
-   * The edges are annotated with {@link MessageIdProperty}, which are outgoing edges of
+   * The edges are annotated with {@link MessageIdEdgeProperty}, which are outgoing edges of
    * parents of the stage put on hold.
    *
    * See {@link org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping.SkewReshapingPass}
@@ -485,23 +485,25 @@ public final class BatchScheduler implements Scheduler {
 
     // Stage put on hold, i.e. stage with vertex containing MessageAggregatorTransform
     // should have a parent stage whose outgoing edges contain the target edge of dynamic optimization.
-    final List<StageEdge> edgesToStagePutOnHold = stageDag.getIncomingEdgesOf(stagePutOnHold);
-    if (edgesToStagePutOnHold.isEmpty()) {
-      throw new RuntimeException("No edges toward specified put on hold stage");
+    final List<Integer> messageIds = stagePutOnHold.getIRDAG()
+      .getVertices()
+      .stream()
+      .filter(v -> v.getPropertyValue(MessageIdVertexProperty.class).isPresent())
+      .map(v -> v.getPropertyValue(MessageIdVertexProperty.class).get())
+      .collect(Collectors.toList());
+    if (messageIds.size() != 1) {
+      throw new IllegalStateException("Must be exactly one vertex with the message id: " + messageIds.toString());
     }
-    final int messageId = edgesToStagePutOnHold.get(0).getPropertyValue(MessageIdProperty.class)
-      .orElseThrow(() -> new RuntimeException("No message id for this put on hold stage"));
-
+    final int messageId = messageIds.get(0);
     final Set<StageEdge> targetEdges = new HashSet<>();
 
-    // Get edges with identical MessageIdProperty (except the put on hold stage)
+    // Get every stage edge whose MessageIdEdgeProperty equals the message id of the stage put on hold
     for (final Stage stage : stageDag.getVertices()) {
       final Set<StageEdge> targetEdgesFound = stageDag.getOutgoingEdgesOf(stage).stream()
         .filter(candidateEdge -> {
           final Optional<Integer> candidateMCId =
-            candidateEdge.getPropertyValue(MessageIdProperty.class);
-          return candidateMCId.isPresent() && candidateMCId.get().equals(messageId)
-            && !edgesToStagePutOnHold.contains(candidateEdge);
+            candidateEdge.getPropertyValue(MessageIdEdgeProperty.class);
+          return candidateMCId.isPresent() && candidateMCId.get().equals(messageId);
         })
         .collect(Collectors.toSet());
       targetEdges.addAll(targetEdgesFound);