You are viewing a plain-text version of this content. The canonical link to it (a hyperlink on the original archive page) is not preserved in this text.
Posted to commits@tez.apache.org by hi...@apache.org on 2014/02/22 21:13:35 UTC

git commit: TEZ-769. Change Vertex.setParallelism() to accept a set of EdgeManagerDescriptors. (hitesh)

Repository: incubator-tez
Updated Branches:
  refs/heads/master 238255b69 -> 649abcbd0


TEZ-769. Change Vertex.setParallelism() to accept a set of EdgeManagerDescriptors. (hitesh)


Project: http://git-wip-us.apache.org/repos/asf/incubator-tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-tez/commit/649abcbd
Tree: http://git-wip-us.apache.org/repos/asf/incubator-tez/tree/649abcbd
Diff: http://git-wip-us.apache.org/repos/asf/incubator-tez/diff/649abcbd

Branch: refs/heads/master
Commit: 649abcbd0f2ca1fceb10f4c2811b1124243698ac
Parents: 238255b
Author: Hitesh Shah <hi...@apache.org>
Authored: Sat Feb 22 12:13:05 2014 -0800
Committer: Hitesh Shah <hi...@apache.org>
Committed: Sat Feb 22 12:13:05 2014 -0800

----------------------------------------------------------------------
 .../tez/dag/api/VertexManagerPluginContext.java |  6 +-
 .../java/org/apache/tez/dag/app/dag/Vertex.java |  4 +-
 .../org/apache/tez/dag/app/dag/impl/Edge.java   | 72 ++++++++++------
 .../apache/tez/dag/app/dag/impl/VertexImpl.java | 40 ++++++---
 .../tez/dag/app/dag/impl/VertexManager.java     |  4 +-
 .../tez/dag/app/dag/impl/TestVertexImpl.java    | 45 +++++++---
 .../org/apache/tez/test/EdgeManagerForTest.java | 18 ++--
 .../vertexmanager/ShuffleVertexManager.java     | 88 ++++++++++++++++----
 .../src/main/proto/ShufflePayloads.proto        |  7 ++
 .../vertexmanager/TestShuffleVertexManager.java | 31 ++++++-
 10 files changed, 230 insertions(+), 85 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/649abcbd/tez-api/src/main/java/org/apache/tez/dag/api/VertexManagerPluginContext.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/VertexManagerPluginContext.java b/tez-api/src/main/java/org/apache/tez/dag/api/VertexManagerPluginContext.java
index f3ca5ef..0e21a92 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/VertexManagerPluginContext.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/VertexManagerPluginContext.java
@@ -64,11 +64,11 @@ public interface VertexManagerPluginContext {
    * disallowed
    * @param parallelism New number of tasks in the vertex
    * @param locationHint the placement policy for tasks.
-   * @param sourceEdgeManagers
+   * @param sourceEdgeManagers Edge Managers to be updated
    * @return true if the operation was allowed.
    */
   public boolean setVertexParallelism(int parallelism, VertexLocationHint locationHint,
-      Map<String, EdgeManager> sourceEdgeManagers);
+      Map<String, EdgeManagerDescriptor> sourceEdgeManagers);
   
   /**
    * Allows a VertexManagerPlugin to assign Events for Root Inputs
@@ -83,7 +83,7 @@ public interface VertexManagerPluginContext {
    *          the Vertex. The target index on individual events represents the
    *          task to which events need to be sent.
    */
-  public void addRootInputEvents(String inputName, Collection<RootInputDataInformationEvent> event);
+  public void addRootInputEvents(String inputName, Collection<RootInputDataInformationEvent> events);
   
   /**
    * Notify the vertex to start the given tasks

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/649abcbd/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java
index 5157401..9e7a0a7 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java
@@ -24,7 +24,7 @@ import java.util.Set;
 
 import org.apache.hadoop.yarn.api.records.Resource;
 import org.apache.tez.common.counters.TezCounters;
-import org.apache.tez.dag.api.EdgeManager;
+import org.apache.tez.dag.api.EdgeManagerDescriptor;
 import org.apache.tez.dag.api.InputDescriptor;
 import org.apache.tez.dag.api.OutputDescriptor;
 import org.apache.tez.dag.api.ProcessorDescriptor;
@@ -77,7 +77,7 @@ public interface Vertex extends Comparable<Vertex> {
 
 
   boolean setParallelism(int parallelism, VertexLocationHint vertexLocationHint,
-      Map<String, EdgeManager> sourceEdgeManagers);
+      Map<String, EdgeManagerDescriptor> sourceEdgeManagers);
   void setVertexLocationHint(VertexLocationHint vertexLocationHint);
 
   // CHANGE THESE TO LISTS AND MAINTAIN ORDER?

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/649abcbd/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/Edge.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/Edge.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/Edge.java
index 55ab86f..7b4a120 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/Edge.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/Edge.java
@@ -26,6 +26,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import org.apache.hadoop.yarn.event.EventHandler;
 import org.apache.tez.dag.api.EdgeManager;
 import org.apache.tez.dag.api.EdgeManagerContext;
+import org.apache.tez.dag.api.EdgeManagerDescriptor;
 import org.apache.tez.dag.api.EdgeProperty;
 import org.apache.tez.dag.api.TezUncheckedException;
 import org.apache.tez.dag.api.EdgeProperty.DataMovementType;
@@ -95,24 +96,28 @@ public class Edge {
   public Edge(EdgeProperty edgeProperty, EventHandler eventHandler) {
     this.edgeProperty = edgeProperty;
     this.eventHandler = eventHandler;
+    createEdgeManager();
+  }
+
+  private void createEdgeManager() {
     switch (edgeProperty.getDataMovementType()) {
-    case ONE_TO_ONE:
-      edgeManager = new OneToOneEdgeManager();
-      break;
-    case BROADCAST:
-      edgeManager = new BroadcastEdgeManager();
-      break;
-    case SCATTER_GATHER:
-      edgeManager = new ScatterGatherEdgeManager();
-      break;
-    case CUSTOM:
-      String edgeManagerClassName = edgeProperty.getEdgeManagerDescriptor().getClassName();
-      edgeManager = RuntimeUtils.createClazzInstance(edgeManagerClassName);
-      break;
-    default:
-      String message = "Unknown edge data movement type: "
-          + edgeProperty.getDataMovementType();
-      throw new TezUncheckedException(message);
+      case ONE_TO_ONE:
+        edgeManager = new OneToOneEdgeManager();
+        break;
+      case BROADCAST:
+        edgeManager = new BroadcastEdgeManager();
+        break;
+      case SCATTER_GATHER:
+        edgeManager = new ScatterGatherEdgeManager();
+        break;
+      case CUSTOM:
+        String edgeManagerClassName = edgeProperty.getEdgeManagerDescriptor().getClassName();
+        edgeManager = RuntimeUtils.createClazzInstance(edgeManagerClassName);
+        break;
+      default:
+        String message = "Unknown edge data movement type: "
+            + edgeProperty.getDataMovementType();
+        throw new TezUncheckedException(message);
     }
   }
 
@@ -130,6 +135,18 @@ public class Edge {
         null);
   }
 
+  public synchronized void setCustomEdgeManager(EdgeManagerDescriptor descriptor) {
+    EdgeProperty modifiedEdgeProperty =
+        new EdgeProperty(descriptor,
+            edgeProperty.getDataSourceType(),
+            edgeProperty.getSchedulingType(),
+            edgeProperty.getEdgeSource(),
+            edgeProperty.getEdgeDestination());
+    this.edgeProperty = modifiedEdgeProperty;
+    createEdgeManager();
+    initialize();
+  }
+
   public EdgeProperty getEdgeProperty() {
     return this.edgeProperty;
   }
@@ -137,15 +154,7 @@ public class Edge {
   public EdgeManager getEdgeManager() {
     return this.edgeManager;
   }
-  
-  public void setEdgeManager(EdgeManager edgeManager) {
-    if(edgeManager == null) {
-      throw new TezUncheckedException("Edge manager cannot be null");
-    }
-    this.edgeManager = edgeManager;
-    this.edgeManager.initialize(edgeManagerContext);
-  }
-  
+
   public void setSourceVertex(Vertex sourceVertex) {
     if (this.sourceVertex != null && this.sourceVertex != sourceVertex) {
       throw new TezUncheckedException("Source vertex exists: "
@@ -173,7 +182,7 @@ public class Edge {
   public OutputSpec getSourceSpec(int sourceTaskIndex) {
     return new OutputSpec(destinationVertex.getName(),
         edgeProperty.getEdgeSource(), edgeManager.getNumSourceTaskPhysicalOutputs(
-            destinationVertex.getTotalTasks(), sourceTaskIndex));
+        destinationVertex.getTotalTasks(), sourceTaskIndex));
   }
   
   public void startEventBuffering() {
@@ -335,4 +344,13 @@ public class Edge {
   private void sendEventToTask(TezTaskID taskId, TezEvent tezEvent) {
     eventHandler.handle(new TaskEventAddTezEvent(taskId, tezEvent));
   }
+
+  public String getSourceVertexName() {
+    return this.sourceVertex.getName();
+  }
+
+  public String getDestinationVertexName() {
+    return this.destinationVertex.getName();
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/649abcbd/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
index 02b602f..5ec55ee 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
@@ -52,6 +52,9 @@ import org.apache.hadoop.yarn.state.StateMachineFactory;
 import org.apache.hadoop.yarn.util.Clock;
 import org.apache.tez.common.counters.TezCounters;
 import org.apache.tez.dag.api.DagTypeConverters;
+import org.apache.tez.dag.api.EdgeManagerContext;
+import org.apache.tez.dag.api.EdgeManagerDescriptor;
+import org.apache.tez.dag.api.EdgeProperty;
 import org.apache.tez.dag.api.EdgeProperty.DataMovementType;
 import org.apache.tez.dag.api.EdgeManager;
 import org.apache.tez.dag.api.InputDescriptor;
@@ -817,7 +820,7 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
 
   @Override
   public boolean setParallelism(int parallelism, VertexLocationHint vertexLocationHint,
-      Map<String, EdgeManager> sourceEdgeManagers) {
+      Map<String, EdgeManagerDescriptor> sourceEdgeManagers) {
     writeLock.lock();
     setVertexLocationHint(vertexLocationHint);
     try {
@@ -836,21 +839,29 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
             " parallelism set to " + parallelism);
 
         if(sourceEdgeManagers != null) {
-          for(Map.Entry<String, EdgeManager> entry : sourceEdgeManagers.entrySet()) {
+          for(Map.Entry<String, EdgeManagerDescriptor> entry : sourceEdgeManagers.entrySet()) {
             LOG.info("Replacing edge manager for source:"
                 + entry.getKey() + " destination: " + getVertexId());
             Vertex sourceVertex = appContext.getCurrentDAG().getVertex(entry.getKey());
-            EdgeManager edgeManager = entry.getValue();
             Edge edge = sourceVertices.get(sourceVertex);
-            edge.setEdgeManager(edgeManager);
+            try {
+              edge.setCustomEdgeManager(entry.getValue());
+            } catch (Exception e) {
+              LOG.warn("Failed to initialize edge manager for edge"
+                  + ", sourceVertexName=" + sourceVertex.getName()
+                  + ", destinationVertexName=" + edge.getDestinationVertexName(),
+                  e);
+              return false;
+            }
           }
         }
       } else {
         if (parallelism >= numTasks) {
           // not that hard to support perhaps. but checking right now since there
           // is no use case for it and checking may catch other bugs.
-          throw new TezUncheckedException(
-              "Increasing parallelism is not supported");
+          LOG.warn("Increasing parallelism is not supported, vertexId="
+              + logIdentifier);
+          return false;
         }
         if (parallelism == numTasks) {
           LOG.info("setParallelism same as current value: " + parallelism);
@@ -881,9 +892,10 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
           Map.Entry<TezTaskID, Task> entry = iter.next();
           Task task = entry.getValue();
           if (task.getState() != TaskState.NEW) {
-            throw new TezUncheckedException(
+            LOG.warn(
                 "All tasks must be in initial state when changing parallelism"
                     + " for vertex: " + getVertexId() + " name: " + getName());
+            return false;
           }
           pendingEvents.addAll(task.getAndClearTaskTezEvents());
           if (i <= parallelism) {
@@ -897,13 +909,21 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
   
         // set new edge managers
         if(sourceEdgeManagers != null) {
-          for(Map.Entry<String, EdgeManager> entry : sourceEdgeManagers.entrySet()) {
+          for(Map.Entry<String, EdgeManagerDescriptor> entry : sourceEdgeManagers.entrySet()) {
             LOG.info("Replacing edge manager for source:"
                 + entry.getKey() + " destination: " + getVertexId());
             Vertex sourceVertex = appContext.getCurrentDAG().getVertex(entry.getKey());
-            EdgeManager edgeManager = entry.getValue();
             Edge edge = sourceVertices.get(sourceVertex);
-            edge.setEdgeManager(edgeManager);
+            EdgeProperty edgeProperty = edge.getEdgeProperty();
+            try {
+              edge.setCustomEdgeManager(entry.getValue());
+            } catch (Exception e) {
+              LOG.warn("Failed to initialize edge manager for edge"
+                  + ", sourceVertexName=" + sourceVertex.getName()
+                  + ", destinationVertexName=" + edge.getDestinationVertexName(),
+                  e);
+              return false;
+            }
           }
         }
   

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/649abcbd/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexManager.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexManager.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexManager.java
index d45e77b..df7696b 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexManager.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexManager.java
@@ -27,7 +27,7 @@ import java.util.Set;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.tez.common.TezUtils;
-import org.apache.tez.dag.api.EdgeManager;
+import org.apache.tez.dag.api.EdgeManagerDescriptor;
 import org.apache.tez.dag.api.EdgeProperty;
 import org.apache.tez.dag.api.InputDescriptor;
 import org.apache.tez.dag.api.TezUncheckedException;
@@ -95,7 +95,7 @@ public class VertexManager {
 
     @Override
     public boolean setVertexParallelism(int parallelism, VertexLocationHint vertexLocationHint,
-        Map<String, EdgeManager> sourceEdgeManagers) {
+        Map<String, EdgeManagerDescriptor> sourceEdgeManagers) {
       return managedVertex.setParallelism(parallelism, vertexLocationHint, sourceEdgeManagers);
     }
 

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/649abcbd/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
index d99c57f..c2ff32d 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
@@ -37,6 +37,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 
 import com.google.protobuf.ByteString;
 import org.apache.commons.logging.Log;
@@ -55,6 +56,8 @@ import org.apache.hadoop.yarn.util.Clock;
 import org.apache.hadoop.yarn.util.SystemClock;
 import org.apache.tez.dag.api.DagTypeConverters;
 import org.apache.tez.dag.api.EdgeManager;
+import org.apache.tez.dag.api.EdgeManagerContext;
+import org.apache.tez.dag.api.EdgeManagerDescriptor;
 import org.apache.tez.dag.api.EdgeProperty;
 import org.apache.tez.dag.api.InputDescriptor;
 import org.apache.tez.dag.api.TezConfiguration;
@@ -111,6 +114,8 @@ import org.apache.tez.dag.records.TezVertexID;
 import org.apache.tez.runtime.api.Event;
 import org.apache.tez.runtime.api.OutputCommitter;
 import org.apache.tez.runtime.api.OutputCommitterContext;
+import org.apache.tez.runtime.api.events.DataMovementEvent;
+import org.apache.tez.runtime.api.events.InputReadErrorEvent;
 import org.apache.tez.runtime.api.events.RootInputConfigureVertexTasksEvent;
 import org.apache.tez.runtime.api.events.RootInputDataInformationEvent;
 import org.apache.tez.test.EdgeManagerForTest;
@@ -1399,16 +1404,20 @@ public class TestVertexImpl {
     startVertex(v3);
 
     Vertex v1 = vertices.get("vertex1");
-    EdgeManager mockEdgeManager = mock(EdgeManager.class);
-    Map<String, EdgeManager> edgeManager = Collections.singletonMap(
-       v1.getName(), mockEdgeManager);
-    v3.setParallelism(1, null, edgeManager);
+    EdgeManagerDescriptor mockEdgeManagerDescriptor =
+        new EdgeManagerDescriptor(EdgeManagerForTest.class.getName());
+
+    Map<String, EdgeManagerDescriptor> edgeManagerDescriptors =
+        Collections.singletonMap(
+       v1.getName(), mockEdgeManagerDescriptor);
+    Assert.assertTrue(v3.setParallelism(1, null, edgeManagerDescriptors));
+    Assert.assertTrue(v3.sourceVertices.get(v1).getEdgeManager() instanceof
+        EdgeManagerForTest);
     Assert.assertEquals(1, v3.getTotalTasks());
     Assert.assertEquals(1, tasks.size());
     // the last one is removed
     Assert.assertTrue(tasks.keySet().iterator().next().equals(firstTask));
 
-    Assert.assertTrue(v3.sourceVertices.get(v1).getEdgeManager() == mockEdgeManager);
   }
 
   @Test(timeout = 5000)
@@ -1417,23 +1426,33 @@ public class TestVertexImpl {
     Edge edge = edges.get("e1");
     EdgeManager em = edge.getEdgeManager();
     EdgeManagerForTest originalEm = (EdgeManagerForTest) em;
-    Assert.assertEquals(true, originalEm.isCreatedByFramework());
     Assert.assertTrue(Arrays.equals(edgePayload, originalEm.getEdgeManagerContext()
         .getUserPayload()));
 
-    em = EdgeManagerForTest.createInstance();
+    byte[] userPayload = new String("foo").getBytes();
+    EdgeManagerDescriptor edgeManagerDescriptor =
+        new EdgeManagerDescriptor(EdgeManagerForTest.class.getName());
+    edgeManagerDescriptor.setUserPayload(userPayload);
+
     Vertex v1 = vertices.get("vertex1");
     Vertex v3 = vertices.get("vertex3"); // Vertex3 linked to v1 (v1 src, v3
                                          // dest)
-    Map<String, EdgeManager> edgeManagers = Collections.singletonMap(v1.getName(), em);
-    v3.setParallelism(v3.getTotalTasks() - 1, null, edgeManagers); // Must decrease.
 
-    EdgeManagerForTest edgeManagerPostSet = (EdgeManagerForTest) edge.getEdgeManager();
-    Assert.assertEquals(false, edgeManagerPostSet.isCreatedByFramework());
+    Map<String, EdgeManagerDescriptor> edgeManagerDescriptors =
+        Collections.singletonMap(v1.getName(), edgeManagerDescriptor);
+    Assert.assertTrue(v3.setParallelism(v3.getTotalTasks() - 1, null,
+        edgeManagerDescriptors)); // Must decrease.
+
+    VertexImpl v3Impl = (VertexImpl) v3;
+
+    EdgeManager modifiedEdgeManager = v3Impl.sourceVertices.get(v1)
+        .getEdgeManager();
+    Assert.assertNotNull(modifiedEdgeManager);
+    Assert.assertTrue(modifiedEdgeManager instanceof EdgeManagerForTest);
 
     // Ensure initialize() is called with the correct payload
-    Assert.assertTrue(Arrays.equals(originalEm.getEdgeManagerContext().getUserPayload(),
-        edgeManagerPostSet.getEdgeManagerContext().getUserPayload()));
+    Assert.assertTrue(Arrays.equals(userPayload,
+        ((EdgeManagerForTest) modifiedEdgeManager).getUserPayload()));
   }
 
   @SuppressWarnings("unchecked")

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/649abcbd/tez-dag/src/test/java/org/apache/tez/test/EdgeManagerForTest.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/test/EdgeManagerForTest.java b/tez-dag/src/test/java/org/apache/tez/test/EdgeManagerForTest.java
index 7fbd4d8..24b39fe 100644
--- a/tez-dag/src/test/java/org/apache/tez/test/EdgeManagerForTest.java
+++ b/tez-dag/src/test/java/org/apache/tez/test/EdgeManagerForTest.java
@@ -29,32 +29,30 @@ import org.apache.tez.runtime.api.events.InputReadErrorEvent;
 public class EdgeManagerForTest implements EdgeManager {
 
   private EdgeManagerContext edgeManagerContext = null;
-  private boolean createdByFramework = true;
+  private byte[] userPayload;
 
   public static EdgeManagerForTest createInstance() {
     EdgeManagerForTest e = new EdgeManagerForTest();
-    e.createdByFramework = false;
     return e;
   }
-  
-  public boolean isCreatedByFramework() {
-    return createdByFramework;
-  }
-  
+
   public EdgeManagerContext getEdgeManagerContext() {
     return edgeManagerContext;
   }
 
   
-  
-  // Overridden methods
-  
   public EdgeManagerForTest() {
   }
 
+  public byte[] getUserPayload() {
+    return userPayload;
+  }
+
+  // Overridden methods
   @Override
   public void initialize(EdgeManagerContext edgeManagerContext) {
     this.edgeManagerContext = edgeManagerContext;
+    this.userPayload = edgeManagerContext.getUserPayload();
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/649abcbd/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
index 70e9fae..f4e1957 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
@@ -33,10 +33,10 @@ import org.apache.hadoop.conf.Configuration;
 import org.apache.tez.common.TezUtils;
 import org.apache.tez.dag.api.EdgeManager;
 import org.apache.tez.dag.api.EdgeManagerContext;
+import org.apache.tez.dag.api.EdgeManagerDescriptor;
 import org.apache.tez.dag.api.EdgeProperty;
 import org.apache.tez.dag.api.InputDescriptor;
 import org.apache.tez.dag.api.TezUncheckedException;
-import org.apache.tez.dag.api.VertexLocationHint;
 import org.apache.tez.dag.api.VertexManagerPlugin;
 import org.apache.tez.dag.api.VertexManagerPluginContext;
 import org.apache.tez.dag.api.EdgeProperty.DataMovementType;
@@ -44,6 +44,7 @@ import org.apache.tez.runtime.api.Event;
 import org.apache.tez.runtime.api.events.DataMovementEvent;
 import org.apache.tez.runtime.api.events.InputReadErrorEvent;
 import org.apache.tez.runtime.api.events.VertexManagerEvent;
+import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.ShuffleEdgeManagerConfigPayloadProto;
 import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.VertexManagerEventPayloadProto;
 
 import com.google.common.collect.Lists;
@@ -125,23 +126,35 @@ public class ShuffleVertexManager implements VertexManagerPlugin {
   }
   
   
-  public class CustomShuffleEdgeManager implements EdgeManager {
+  public static class CustomShuffleEdgeManager implements EdgeManager {
     int numSourceTaskOutputs;
     int numDestinationTasks;
     int basePartitionRange;
     int remainderRangeForLastShuffler;
-    
-    CustomShuffleEdgeManager(int numSourceTaskOutputs, int numDestinationTasks,
-        int basePartitionRange, int remainderPartitionForLastShuffler) {
-      this.numSourceTaskOutputs = numSourceTaskOutputs;
-      this.numDestinationTasks = numDestinationTasks;
-      this.basePartitionRange = basePartitionRange;
-      this.remainderRangeForLastShuffler = remainderPartitionForLastShuffler;
+
+    public CustomShuffleEdgeManager() {
     }
 
     @Override
     public void initialize(EdgeManagerContext edgeManagerContext) {
       // Nothing to do. This class isn't currently designed to be used at the DAG API level.
+      byte[] userPayload = edgeManagerContext.getUserPayload();
+      if (userPayload == null
+        || userPayload.length == 0) {
+        throw new RuntimeException("Could not initialize CustomShuffleEdgeManager"
+            + " from provided user payload");
+      }
+      CustomShuffleEdgeManagerConfig config;
+      try {
+        config = CustomShuffleEdgeManagerConfig.fromUserPayload(userPayload);
+      } catch (InvalidProtocolBufferException e) {
+        throw new RuntimeException("Could not initialize CustomShuffleEdgeManager"
+            + " from provided user payload", e);
+      }
+      this.numSourceTaskOutputs = config.numSourceTaskOutputs;
+      this.numDestinationTasks = config.numDestinationTasks;
+      this.basePartitionRange = config.basePartitionRange;
+      this.remainderRangeForLastShuffler = config.remainderRangeForLastShuffler;
     }
 
     @Override
@@ -179,7 +192,7 @@ public class ShuffleVertexManager implements VertexManagerPlugin {
           sourceTaskIndex * partitionRange 
           + sourceIndex % partitionRange;
       
-      inputIndicesToTaskIndices.put(new Integer(targetIndex), 
+      inputIndicesToTaskIndices.put(new Integer(targetIndex),
           Collections.singletonList(new Integer(destinationTaskIndex)));
     }
     
@@ -233,6 +246,44 @@ public class ShuffleVertexManager implements VertexManagerPlugin {
         int numDestTasks) {
       return numDestTasks;
     }
+   }
+
+  private static class CustomShuffleEdgeManagerConfig {
+    int numSourceTaskOutputs;
+    int numDestinationTasks;
+    int basePartitionRange;
+    int remainderRangeForLastShuffler;
+
+    private CustomShuffleEdgeManagerConfig(int numSourceTaskOutputs,
+        int numDestinationTasks,
+        int basePartitionRange,
+        int remainderRangeForLastShuffler) {
+      this.numSourceTaskOutputs = numSourceTaskOutputs;
+      this.numDestinationTasks = numDestinationTasks;
+      this.basePartitionRange = basePartitionRange;
+      this.remainderRangeForLastShuffler = remainderRangeForLastShuffler;
+    }
+
+    public byte[] toUserPayload() {
+      return ShuffleEdgeManagerConfigPayloadProto.newBuilder()
+          .setNumSourceTaskOutputs(numSourceTaskOutputs)
+          .setNumDestinationTasks(numDestinationTasks)
+          .setBasePartitionRange(basePartitionRange)
+          .setRemainderRangeForLastShuffler(remainderRangeForLastShuffler)
+          .build().toByteArray();
+    }
+
+    public static CustomShuffleEdgeManagerConfig fromUserPayload(
+        byte[] userPayload) throws InvalidProtocolBufferException {
+      ShuffleEdgeManagerConfigPayloadProto proto =
+          ShuffleEdgeManagerConfigPayloadProto.parseFrom(userPayload);
+      return new CustomShuffleEdgeManagerConfig(
+          proto.getNumSourceTaskOutputs(),
+          proto.getNumDestinationTasks(),
+          proto.getBasePartitionRange(),
+          proto.getRemainderRangeForLastShuffler());
+
+    }
   }
 
   
@@ -360,15 +411,20 @@ public class ShuffleVertexManager implements VertexManagerPlugin {
           
     if(finalTaskParallelism < currentParallelism) {
       // final parallelism is less than actual parallelism
-      Map<String, EdgeManager> edgeManagers = new HashMap<String, EdgeManager>(
-          bipartiteSources.size());
+      Map<String, EdgeManagerDescriptor> edgeManagers =
+          new HashMap<String, EdgeManagerDescriptor>(bipartiteSources.size());
       for(String vertex : bipartiteSources.keySet()) {
         // use currentParallelism for numSourceTasks to maintain original state
         // for the source tasks
-        edgeManagers.put(vertex, new CustomShuffleEdgeManager(
-            currentParallelism, finalTaskParallelism, basePartitionRange,
-            ((remainderRangeForLastShuffler > 0) ?
-                remainderRangeForLastShuffler : basePartitionRange)));
+        CustomShuffleEdgeManagerConfig edgeManagerConfig =
+            new CustomShuffleEdgeManagerConfig(
+                currentParallelism, finalTaskParallelism, basePartitionRange,
+                ((remainderRangeForLastShuffler > 0) ?
+                    remainderRangeForLastShuffler : basePartitionRange));
+        EdgeManagerDescriptor edgeManagerDescriptor =
+            new EdgeManagerDescriptor(CustomShuffleEdgeManager.class.getName());
+        edgeManagerDescriptor.setUserPayload(edgeManagerConfig.toUserPayload());
+        edgeManagers.put(vertex, edgeManagerDescriptor);
       }
       
       context.setVertexParallelism(finalTaskParallelism, null, edgeManagers);

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/649abcbd/tez-runtime-library/src/main/proto/ShufflePayloads.proto
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/proto/ShufflePayloads.proto b/tez-runtime-library/src/main/proto/ShufflePayloads.proto
index 34767ba..b4ae332 100644
--- a/tez-runtime-library/src/main/proto/ShufflePayloads.proto
+++ b/tez-runtime-library/src/main/proto/ShufflePayloads.proto
@@ -42,3 +42,10 @@ message InputInformationEventPayloadProto {
 message VertexManagerEventPayloadProto {
   optional int64 output_size = 1;
 }
+
+message ShuffleEdgeManagerConfigPayloadProto {
+  optional int32 num_source_task_outputs = 1;
+  optional int32 num_destination_tasks = 2;
+  optional int32 base_partition_range = 3;
+  optional int32 remainder_range_for_last_shuffler = 4;
+}

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/649abcbd/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestShuffleVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestShuffleVertexManager.java b/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestShuffleVertexManager.java
index fd11378..334ebb4 100644
--- a/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestShuffleVertexManager.java
+++ b/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestShuffleVertexManager.java
@@ -23,10 +23,13 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.tez.common.TezUtils;
 import org.apache.tez.dag.api.EdgeManager;
+import org.apache.tez.dag.api.EdgeManagerContext;
+import org.apache.tez.dag.api.EdgeManagerDescriptor;
 import org.apache.tez.dag.api.EdgeProperty;
 import org.apache.tez.dag.api.InputDescriptor;
 import org.apache.tez.dag.api.OutputDescriptor;
@@ -34,6 +37,7 @@ import org.apache.tez.dag.api.TezUncheckedException;
 import org.apache.tez.dag.api.VertexLocationHint;
 import org.apache.tez.dag.api.EdgeProperty.SchedulingType;
 import org.apache.tez.dag.api.VertexManagerPluginContext;
+import org.apache.tez.runtime.RuntimeUtils;
 import org.apache.tez.runtime.api.events.DataMovementEvent;
 import org.apache.tez.runtime.api.events.VertexManagerEvent;
 import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.VertexManagerEventPayloadProto;
@@ -108,13 +112,36 @@ public class TestShuffleVertexManager {
           return null;
       }}).when(mockContext).scheduleVertexTasks(anyList());
     
-    final Map<String, EdgeManager> newEdgeManagers = new HashMap<String, EdgeManager>();
+    final Map<String, EdgeManager> newEdgeManagers =
+        new HashMap<String, EdgeManager>();
     
     doAnswer(new Answer() {
       public Object answer(InvocationOnMock invocation) {
           when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(2);
           newEdgeManagers.clear();
-          newEdgeManagers.putAll((Map<String, EdgeManager>)invocation.getArguments()[2]);
+          for (Entry<String, EdgeManagerDescriptor> entry :
+              ((Map<String, EdgeManagerDescriptor>)invocation.getArguments()[2]).entrySet()) {
+            EdgeManager edgeManager = RuntimeUtils.createClazzInstance(
+                entry.getValue().getClassName());
+            final byte[] userPayload = entry.getValue().getUserPayload();
+            edgeManager.initialize(new EdgeManagerContext() {
+              @Override
+              public byte[] getUserPayload() {
+                return userPayload;
+              }
+
+              @Override
+              public String getSrcVertexName() {
+                return null;
+              }
+
+              @Override
+              public String getDestVertexName() {
+                return null;
+              }
+            });
+            newEdgeManagers.put(entry.getKey(), edgeManager);
+          }
           return null;
       }}).when(mockContext).setVertexParallelism(eq(2), any(VertexLocationHint.class), anyMap());