You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tez.apache.org by ss...@apache.org on 2015/05/06 09:41:35 UTC

[33/50] [abbrv] tez git commit: TEZ-2006. Task communication plane needs to be pluggable. (sseth)

TEZ-2006. Task communication plane needs to be pluggable. (sseth)


Project: http://git-wip-us.apache.org/repos/asf/tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/56986504
Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/56986504
Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/56986504

Branch: refs/heads/TEZ-2003
Commit: 56986504695e995a87c8923a9713002ab1380468
Parents: 0fc4c69
Author: Siddharth Seth <ss...@apache.org>
Authored: Thu Feb 12 11:25:45 2015 -0800
Committer: Siddharth Seth <ss...@apache.org>
Committed: Wed May 6 00:13:29 2015 -0700

----------------------------------------------------------------------
 TEZ-2003-CHANGES.txt                            |   1 +
 .../apache/tez/dag/api/TaskCommunicator.java    |  54 ++
 .../tez/dag/api/TaskCommunicatorContext.java    |  48 ++
 .../tez/dag/api/TaskHeartbeatRequest.java       |  63 +++
 .../tez/dag/api/TaskHeartbeatResponse.java      |  39 ++
 .../java/org/apache/tez/dag/app/AppContext.java |   3 +
 .../org/apache/tez/dag/app/DAGAppMaster.java    |   5 +
 .../dag/app/TaskAttemptListenerImpTezDag.java   | 517 +++++++------------
 .../tez/dag/app/TezTaskCommunicatorImpl.java    | 474 +++++++++++++++++
 .../app/launcher/LocalContainerLauncher.java    |  10 +-
 .../tez/dag/app/rm/container/AMContainer.java   |   3 +-
 .../rm/container/AMContainerEventAssignTA.java  |   2 +
 .../dag/app/rm/container/AMContainerImpl.java   |   1 +
 .../apache/tez/dag/app/MockDAGAppMaster.java    |  25 +-
 .../app/TestTaskAttemptListenerImplTezDag.java  |  82 +--
 15 files changed, 965 insertions(+), 362 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/tez/blob/56986504/TEZ-2003-CHANGES.txt
----------------------------------------------------------------------
diff --git a/TEZ-2003-CHANGES.txt b/TEZ-2003-CHANGES.txt
index 1822fcb..d7e4be5 100644
--- a/TEZ-2003-CHANGES.txt
+++ b/TEZ-2003-CHANGES.txt
@@ -1,4 +1,5 @@
 ALL CHANGES:
   TEZ-2019. Temporarily allow the scheduler and launcher to be specified via configuration.
+  TEZ-2006. Task communication plane needs to be pluggable.
 
 INCOMPATIBLE CHANGES:

http://git-wip-us.apache.org/repos/asf/tez/blob/56986504/tez-dag/src/main/java/org/apache/tez/dag/api/TaskCommunicator.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/api/TaskCommunicator.java b/tez-dag/src/main/java/org/apache/tez/dag/api/TaskCommunicator.java
new file mode 100644
index 0000000..97f9c16
--- /dev/null
+++ b/tez-dag/src/main/java/org/apache/tez/dag/api/TaskCommunicator.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.api;
+
+import java.net.InetSocketAddress;
+import java.util.Map;
+
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.service.AbstractService;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.LocalResource;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.runtime.api.impl.TaskSpec;
+
+// TODO TEZ-2003 Move this into the tez-api module
+public abstract class TaskCommunicator extends AbstractService {
+  public TaskCommunicator(String name) {
+    super(name);
+  }
+
+  // TODO TEZ-2003 Ideally, don't expose YARN containerId; instead expose a Tez specific construct.
+  // TODO When talking to an external service, this plugin implementer may need access to a host:port
+  public abstract void registerRunningContainer(ContainerId containerId, String hostname, int port);
+
+  // TODO TEZ-2003 Ideally, don't expose YARN containerId; instead expose a Tez specific construct.
+  public abstract void registerContainerEnd(ContainerId containerId);
+
+  // TODO TEZ-2003 TaskSpec breakup into a clean interface
+  // TODO TEZ-2003 Add support for priority
+  public abstract void registerRunningTaskAttempt(ContainerId containerId, TaskSpec taskSpec,
+                                                  Map<String, LocalResource> additionalResources,
+                                                  Credentials credentials,
+                                                  boolean credentialsChanged);
+
+  // TODO TEZ-2003 Remove reference to TaskAttemptID
+  public abstract void unregisterRunningTaskAttempt(TezTaskAttemptID taskAttemptID);
+
+  // TODO TEZ-2003 This doesn't necessarily belong here. A server may not start within the AM.
+  public abstract InetSocketAddress getAddress();
+
+  // TODO Eventually. Add methods here to support preemption of tasks.
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/56986504/tez-dag/src/main/java/org/apache/tez/dag/api/TaskCommunicatorContext.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/api/TaskCommunicatorContext.java b/tez-dag/src/main/java/org/apache/tez/dag/api/TaskCommunicatorContext.java
new file mode 100644
index 0000000..9b2d889
--- /dev/null
+++ b/tez-dag/src/main/java/org/apache/tez/dag/api/TaskCommunicatorContext.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.api;
+
+import java.io.IOException;
+
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+
+
+// Do not make calls into this from within a held lock.
+
+// TODO TEZ-2003 Move this into the tez-api module
+public interface TaskCommunicatorContext {
+
+  // TODO TEZ-2003 Add signalling back into this to indicate errors - e.g. Container unregistered, task no longer running, etc.
+
+  // TODO TEZ-2003 Maybe add book-keeping as a helper library, instead of each impl tracking container to task etc.
+
+  ApplicationAttemptId getApplicationAttemptId();
+  Credentials getCredentials();
+
+  // TODO TEZ-2003 Move to vertex, taskIndex, version
+  boolean canCommit(TezTaskAttemptID taskAttemptId) throws IOException;
+
+  TaskHeartbeatResponse heartbeat(TaskHeartbeatRequest request) throws IOException, TezException;
+
+  boolean isKnownContainer(ContainerId containerId);
+
+  // TODO TEZ-2003 Move to vertex, taskIndex, version
+  void taskStartedRemotely(TezTaskAttemptID taskAttemptID, ContainerId containerId);
+
+  // TODO Eventually Add methods to report availability stats to the scheduler.
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/56986504/tez-dag/src/main/java/org/apache/tez/dag/api/TaskHeartbeatRequest.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/api/TaskHeartbeatRequest.java b/tez-dag/src/main/java/org/apache/tez/dag/api/TaskHeartbeatRequest.java
new file mode 100644
index 0000000..f6bc8f0
--- /dev/null
+++ b/tez-dag/src/main/java/org/apache/tez/dag/api/TaskHeartbeatRequest.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.api;
+
+import java.util.List;
+
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.runtime.api.impl.TezEvent;
+
+// TODO TEZ-2003 Move this into the tez-api module
+public class TaskHeartbeatRequest {
+
+  // TODO TEZ-2003 Ideally containerIdentifier should not be part of the request.
+  // Replace with a task lookup - vertex name + task index
+  private final String containerIdentifier;
+  // TODO TEZ-2003 Get rid of the task attemptId reference if possible
+  private final TezTaskAttemptID taskAttemptId;
+  private final List<TezEvent> events;
+  private final int startIndex;
+  private final int maxEvents;
+
+
+  public TaskHeartbeatRequest(String containerIdentifier, TezTaskAttemptID taskAttemptId, List<TezEvent> events, int startIndex,
+                              int maxEvents) {
+    this.containerIdentifier = containerIdentifier;
+    this.taskAttemptId = taskAttemptId;
+    this.events = events;
+    this.startIndex = startIndex;
+    this.maxEvents = maxEvents;
+  }
+
+  public String getContainerIdentifier() {
+    return containerIdentifier;
+  }
+
+  public TezTaskAttemptID getTaskAttemptId() {
+    return taskAttemptId;
+  }
+
+  public List<TezEvent> getEvents() {
+    return events;
+  }
+
+  public int getStartIndex() {
+    return startIndex;
+  }
+
+  public int getMaxEvents() {
+    return maxEvents;
+  }
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/56986504/tez-dag/src/main/java/org/apache/tez/dag/api/TaskHeartbeatResponse.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/api/TaskHeartbeatResponse.java b/tez-dag/src/main/java/org/apache/tez/dag/api/TaskHeartbeatResponse.java
new file mode 100644
index 0000000..c82a743
--- /dev/null
+++ b/tez-dag/src/main/java/org/apache/tez/dag/api/TaskHeartbeatResponse.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.api;
+
+import java.util.List;
+
+import org.apache.tez.runtime.api.impl.TezEvent;
+
+// TODO TEZ-2003 Move this into the tez-api module
+public class TaskHeartbeatResponse {
+
+  private final boolean shouldDie;
+  private List<TezEvent> events;
+
+  public TaskHeartbeatResponse(boolean shouldDie, List<TezEvent> events) {
+    this.shouldDie = shouldDie;
+    this.events = events;
+  }
+
+  public boolean isShouldDie() {
+    return shouldDie;
+  }
+
+  public List<TezEvent> getEvents() {
+    return events;
+  }
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/56986504/tez-dag/src/main/java/org/apache/tez/dag/app/AppContext.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/AppContext.java b/tez-dag/src/main/java/org/apache/tez/dag/app/AppContext.java
index 4781784..37f7624 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/AppContext.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/AppContext.java
@@ -24,6 +24,7 @@ import java.util.Set;
 import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.security.Credentials;
 import org.apache.hadoop.yarn.api.records.ApplicationAccessType;
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
@@ -106,4 +107,6 @@ public interface AppContext {
   String[] getLocalDirs();
 
   String getAMUser();
+
+  Credentials getAppCredentials();
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/56986504/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java b/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java
index 73ee56e..bfc2d58 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java
@@ -1486,6 +1486,11 @@ public class DAGAppMaster extends AbstractService {
     }
 
     @Override
+    public Credentials getAppCredentials() {
+      return amCredentials;
+    }
+
+    @Override
     public Map<ApplicationAccessType, String> getApplicationACLs() {
       if (getServiceState() != STATE.STARTED) {
         throw new TezUncheckedException(

http://git-wip-us.apache.org/repos/asf/tez/blob/56986504/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java b/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java
index d96da83..c34723a 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java
@@ -18,15 +18,14 @@
 package org.apache.tez.dag.app;
 
 import java.io.IOException;
-import java.net.InetAddress;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
 import java.net.InetSocketAddress;
 import java.net.URISyntaxException;
 import java.net.UnknownHostException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Map;
-import java.util.Map.Entry;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
 
@@ -38,216 +37,212 @@ import org.apache.tez.runtime.api.events.TaskStatusUpdateEvent;
 import org.apache.tez.runtime.api.impl.EventType;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
-import org.apache.hadoop.ipc.ProtocolSignature;
-import org.apache.hadoop.ipc.RPC;
-import org.apache.hadoop.ipc.Server;
-import org.apache.hadoop.net.NetUtils;
-import org.apache.hadoop.security.authorize.PolicyProvider;
+import org.apache.hadoop.security.Credentials;
 import org.apache.hadoop.service.AbstractService;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.NodeId;
+import org.apache.tez.common.ReflectionUtils;
+import org.apache.tez.dag.api.TaskCommunicator;
+import org.apache.tez.dag.api.TaskCommunicatorContext;
+import org.apache.tez.dag.api.TaskHeartbeatResponse;
+import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.api.TezException;
 import org.apache.tez.dag.api.TezUncheckedException;
 import org.apache.hadoop.yarn.api.records.ContainerId;
-import org.apache.hadoop.yarn.api.records.LocalResource;
 import org.apache.hadoop.yarn.util.ConverterUtils;
-import org.apache.tez.common.ContainerContext;
-import org.apache.tez.common.ContainerTask;
-import org.apache.tez.common.TezConverterUtils;
-import org.apache.tez.common.TezLocalResource;
-import org.apache.tez.common.TezTaskUmbilicalProtocol;
-import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.TaskHeartbeatRequest;
 import org.apache.tez.dag.app.dag.DAG;
 import org.apache.tez.dag.app.dag.Task;
 import org.apache.tez.dag.app.dag.event.TaskAttemptEventStartedRemotely;
 import org.apache.tez.dag.app.dag.event.VertexEventRouteEvent;
+import org.apache.tez.dag.app.rm.TaskSchedulerService;
 import org.apache.tez.dag.app.rm.container.AMContainerTask;
-import org.apache.tez.dag.app.security.authorize.TezAMPolicyProvider;
 import org.apache.tez.dag.records.TezTaskAttemptID;
 import org.apache.tez.dag.records.TezVertexID;
 import org.apache.tez.runtime.api.impl.TezEvent;
-import org.apache.tez.runtime.api.impl.TezHeartbeatRequest;
-import org.apache.tez.runtime.api.impl.TezHeartbeatResponse;
 import org.apache.tez.common.security.JobTokenSecretManager;
 
-import com.google.common.collect.Maps;
 
 @SuppressWarnings("unchecked")
+@InterfaceAudience.Private
 public class TaskAttemptListenerImpTezDag extends AbstractService implements
-    TezTaskUmbilicalProtocol, TaskAttemptListener {
-
-  private static final ContainerTask TASK_FOR_INVALID_JVM = new ContainerTask(
-      null, true, null, null, false);
+    TaskAttemptListener, TaskCommunicatorContext {
 
   private static final Logger LOG = LoggerFactory
       .getLogger(TaskAttemptListenerImpTezDag.class);
 
   private final AppContext context;
+  private TaskCommunicator taskCommunicator;
 
   protected final TaskHeartbeatHandler taskHeartbeatHandler;
   protected final ContainerHeartbeatHandler containerHeartbeatHandler;
-  private final JobTokenSecretManager jobTokenSecretManager;
-  private InetSocketAddress address;
-  private Server server;
-
-  static class ContainerInfo {
-    ContainerInfo() {
-      this.lastReponse = null;
-      this.lastRequestId = 0;
-      this.amContainerTask = null;
-      this.taskPulled = false;
+
+  private final TaskHeartbeatResponse RESPONSE_SHOULD_DIE = new TaskHeartbeatResponse(true, null);
+
+  private final ConcurrentMap<TezTaskAttemptID, ContainerId> registeredAttempts =
+      new ConcurrentHashMap<TezTaskAttemptID, ContainerId>();
+  private final ConcurrentMap<ContainerId, ContainerInfo> registeredContainers =
+      new ConcurrentHashMap<ContainerId, ContainerInfo>();
+
+  // Defined primarily to work around ConcurrentMaps not accepting null values
+  private static final class ContainerInfo {
+    TezTaskAttemptID taskAttemptId;
+    ContainerInfo(TezTaskAttemptID taskAttemptId) {
+      this.taskAttemptId = taskAttemptId;
     }
-    long lastRequestId;
-    TezHeartbeatResponse lastReponse;
-    AMContainerTask amContainerTask;
-    boolean taskPulled;
   }
 
-  private ConcurrentMap<TezTaskAttemptID, ContainerId> attemptToInfoMap =
-      new ConcurrentHashMap<TezTaskAttemptID, ContainerId>();
+  private static final ContainerInfo NULL_CONTAINER_INFO = new ContainerInfo(null);
 
-  private ConcurrentHashMap<ContainerId, ContainerInfo> registeredContainers =
-      new ConcurrentHashMap<ContainerId, ContainerInfo>();
 
   public TaskAttemptListenerImpTezDag(AppContext context,
-      TaskHeartbeatHandler thh, ContainerHeartbeatHandler chh,
-      JobTokenSecretManager jobTokenSecretManager) {
+                                      TaskHeartbeatHandler thh, ContainerHeartbeatHandler chh,
+                                      // TODO TEZ-2003 pre-merge. Remove reference to JobTokenSecretManager.
+                                      JobTokenSecretManager jobTokenSecretManager) {
     super(TaskAttemptListenerImpTezDag.class.getName());
     this.context = context;
-    this.jobTokenSecretManager = jobTokenSecretManager;
     this.taskHeartbeatHandler = thh;
     this.containerHeartbeatHandler = chh;
+    this.taskCommunicator = new TezTaskCommunicatorImpl(this);
   }
 
   @Override
-  public void serviceStart() {
-    startRpcServer();
-  }
-
-  protected void startRpcServer() {
-    Configuration conf = getConfig();
-    if (!conf.getBoolean(TezConfiguration.TEZ_LOCAL_MODE, TezConfiguration.TEZ_LOCAL_MODE_DEFAULT)) {
-      try {
-        server = new RPC.Builder(conf)
-            .setProtocol(TezTaskUmbilicalProtocol.class)
-            .setBindAddress("0.0.0.0")
-            .setPort(0)
-            .setInstance(this)
-            .setNumHandlers(
-                conf.getInt(TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT,
-                    TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT_DEFAULT))
-            .setSecretManager(jobTokenSecretManager).build();
-
-        // Enable service authorization?
-        if (conf.getBoolean(
-            CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHORIZATION,
-            false)) {
-          refreshServiceAcls(conf, new TezAMPolicyProvider());
-        }
-
-        server.start();
-        this.address = NetUtils.getConnectAddress(server);
-      } catch (IOException e) {
-        throw new TezUncheckedException(e);
-      }
+  public void serviceInit(Configuration conf) {
+    String taskCommClassName = conf.get(TezConfiguration.TEZ_AM_TASK_COMMUNICATOR_CLASS);
+    if (taskCommClassName == null) {
+      LOG.info("Using Default Task Communicator");
+      this.taskCommunicator = new TezTaskCommunicatorImpl(this);
     } else {
+      LOG.info("Using TaskCommunicator: " + taskCommClassName);
+      Class<? extends TaskCommunicator> taskCommClazz = (Class<? extends TaskCommunicator>) ReflectionUtils
+          .getClazz(taskCommClassName);
       try {
-        this.address = new InetSocketAddress(InetAddress.getLocalHost(), 0);
-      } catch (UnknownHostException e) {
+        Constructor<? extends TaskCommunicator> ctor = taskCommClazz.getConstructor(TaskCommunicatorContext.class);
+        ctor.setAccessible(true);
+        this.taskCommunicator = ctor.newInstance(this);
+      } catch (NoSuchMethodException e) {
+        throw new TezUncheckedException(e);
+      } catch (InvocationTargetException e) {
+        throw new TezUncheckedException(e);
+      } catch (InstantiationException e) {
+        throw new TezUncheckedException(e);
+      } catch (IllegalAccessException e) {
         throw new TezUncheckedException(e);
-      }
-      if (LOG.isDebugEnabled()) {
-        LOG.debug("Not starting TaskAttemptListener RPC in LocalMode");
       }
     }
   }
 
-  void refreshServiceAcls(Configuration configuration,
-      PolicyProvider policyProvider) {
-    this.server.refreshServiceAcl(configuration, policyProvider);
+  @Override
+  public void serviceStart() {
+    taskCommunicator.init(getConfig());
+    taskCommunicator.start();
   }
 
   @Override
   public void serviceStop() {
-    stopRpcServer();
-  }
-
-  protected void stopRpcServer() {
-    if (server != null) {
-      server.stop();
+    if (taskCommunicator != null) {
+      taskCommunicator.stop();
+      taskCommunicator = null;
     }
   }
 
-  public InetSocketAddress getAddress() {
-    return address;
-  }
-
   @Override
-  public long getProtocolVersion(String protocol, long clientVersion)
-      throws IOException {
-    return versionID;
+  public ApplicationAttemptId getApplicationAttemptId() {
+    return context.getApplicationAttemptId();
   }
 
   @Override
-  public ProtocolSignature getProtocolSignature(String protocol,
-      long clientVersion, int clientMethodsHash) throws IOException {
-    return ProtocolSignature.getProtocolSignature(this, protocol,
-        clientVersion, clientMethodsHash);
+  public Credentials getCredentials() {
+    return context.getAppCredentials();
   }
 
   @Override
-  public ContainerTask getTask(ContainerContext containerContext)
-      throws IOException {
+  public TaskHeartbeatResponse heartbeat(TaskHeartbeatRequest request)
+      throws IOException, TezException {
+    ContainerId containerId = ConverterUtils.toContainerId(request
+        .getContainerIdentifier());
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Received heartbeat from container"
+          + ", request=" + request);
+    }
 
-    ContainerTask task = null;
+    if (!registeredContainers.containsKey(containerId)) {
+      LOG.warn("Received task heartbeat from unknown container with id: " + containerId +
+          ", asking it to die");
+      return RESPONSE_SHOULD_DIE;
+    }
 
-    if (containerContext == null || containerContext.getContainerIdentifier() == null) {
-      LOG.info("Invalid task request with an empty containerContext or containerId");
-      task = TASK_FOR_INVALID_JVM;
-    } else {
-      ContainerId containerId = ConverterUtils.toContainerId(containerContext
-          .getContainerIdentifier());
+    // A heartbeat can come in anytime. The AM may have made a decision to kill a running task/container
+    // meanwhile. If the decision is processed through the pipeline before the heartbeat is processed,
+    // the heartbeat will be dropped. Otherwise the heartbeat will be processed - and the system
+    // knows how to handle this - via FailedInputEvents for example (relevant only if the heartbeat has events).
+    // So - avoiding synchronization.
+
+    pingContainerHeartbeatHandler(containerId);
+    List<TezEvent> outEvents = null;
+    TezTaskAttemptID taskAttemptID = request.getTaskAttemptId();
+    if (taskAttemptID != null) {
+      ContainerId containerIdFromMap = registeredAttempts.get(taskAttemptID);
+      if (containerIdFromMap == null || !containerIdFromMap.equals(containerId)) {
+        // This can happen when a task heartbeats. Meanwhile the container is unregistered.
+        // The information will eventually make it through to the plugin via a corresponding unregister.
+        // There's a race in that case between the unregister making it through, and this method returning.
+        // TODO TEZ-2003. An exception back is likely a better approach than sending a shouldDie = true,
+        // so that the plugin can handle the scenario. Alternately augment the response with error codes.
+        // Error codes would be better than exceptions.
+        LOG.info("Attempt: " + taskAttemptID + " is not recognized for heartbeats");
+        return RESPONSE_SHOULD_DIE;
+      }
+
+      List<TezEvent> inEvents = request.getEvents();
       if (LOG.isDebugEnabled()) {
-        LOG.debug("Container with id: " + containerId + " asked for a task");
+        LOG.debug("Ping from " + taskAttemptID.toString() +
+            " events: " + (inEvents != null ? inEvents.size() : -1));
       }
-      if (!registeredContainers.containsKey(containerId)) {
-        if(context.getAllContainers().get(containerId) == null) {
-          LOG.info("Container with id: " + containerId
-              + " is invalid and will be killed");
-        } else {
-          LOG.info("Container with id: " + containerId
-              + " is valid, but no longer registered, and will be killed");
-        }
-        task = TASK_FOR_INVALID_JVM;
-      } else {
-        pingContainerHeartbeatHandler(containerId);
-        task = getContainerTask(containerId);
-        if (task == null) {
-          if (LOG.isDebugEnabled()) {
-            LOG.debug("No task current assigned to Container with id: " + containerId);
-          }
-        } else if (task == TASK_FOR_INVALID_JVM) { 
-          LOG.info("Container with id: " + containerId
-              + " is valid, but no longer registered, and will be killed. Race condition.");          
+
+      List<TezEvent> otherEvents = new ArrayList<TezEvent>();
+      for (TezEvent tezEvent : ListUtils.emptyIfNull(inEvents)) {
+        final EventType eventType = tezEvent.getEventType();
+        if (eventType == EventType.TASK_STATUS_UPDATE_EVENT ||
+            eventType == EventType.TASK_ATTEMPT_COMPLETED_EVENT) {
+          context.getEventHandler()
+              .handle(getTaskAttemptEventFromTezEvent(taskAttemptID, tezEvent));
         } else {
-          context.getEventHandler().handle(
-              new TaskAttemptEventStartedRemotely(task.getTaskSpec()
-                  .getTaskAttemptID(), containerId, context
-                  .getApplicationACLs()));
-          LOG.info("Container with id: " + containerId + " given task: "
-              + task.getTaskSpec().getTaskAttemptID());
+          otherEvents.add(tezEvent);
         }
       }
+      if(!otherEvents.isEmpty()) {
+        TezVertexID vertexId = taskAttemptID.getTaskID().getVertexID();
+        context.getEventHandler().handle(
+            new VertexEventRouteEvent(vertexId, Collections.unmodifiableList(otherEvents)));
+      }
+      taskHeartbeatHandler.pinged(taskAttemptID);
+      outEvents = context
+          .getCurrentDAG()
+          .getVertex(taskAttemptID.getTaskID().getVertexID())
+          .getTask(taskAttemptID.getTaskID())
+          .getTaskAttemptTezEvents(taskAttemptID, request.getStartIndex(),
+              request.getMaxEvents());
     }
-    if (LOG.isDebugEnabled()) {
-      LOG.debug("getTask returning task: " + task);
-    }
-    return task;
+    return new TaskHeartbeatResponse(false, outEvents);
+  }
+
+  @Override
+  public boolean isKnownContainer(ContainerId containerId) {
+    return context.getAllContainers().get(containerId) != null;
+  }
+
+  @Override
+  public void taskStartedRemotely(TezTaskAttemptID taskAttemptID, ContainerId containerId) {
+    context.getEventHandler().handle(new TaskAttemptEventStartedRemotely(taskAttemptID, containerId, null));
+    pingContainerHeartbeatHandler(containerId);
   }
 
   /**
    * Child checking whether it can commit.
-   *
+   * <p/>
    * <br/>
    * Repeatedly polls the ApplicationMaster whether it
    * {@link Task#canCommit(TezTaskAttemptID)} This is * a legacy from the
@@ -270,25 +265,12 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements
   }
 
   @Override
-  public void unregisterTaskAttempt(TezTaskAttemptID attemptId) {
-    ContainerId containerId = attemptToInfoMap.get(attemptId);
-    if(containerId == null) {
-      LOG.warn("Unregister task attempt: " + attemptId + " from unknown container");
-      return;
-    }
-    ContainerInfo containerInfo = registeredContainers.get(containerId);
-    if(containerInfo == null) {
-      LOG.warn("Unregister task attempt: " + attemptId +
-          " from non-registered container: " + containerId);
-      return;
-    }
-    synchronized (containerInfo) {
-      containerInfo.amContainerTask = null;
-      attemptToInfoMap.remove(attemptId);
-    }
-
+  public InetSocketAddress getAddress() {
+    return taskCommunicator.getAddress();
   }
 
+  // The TaskAttemptListener register / unregister methods in this class are not thread safe.
+  // The Tez framework should not invoke these methods from multiple threads.
   @Override
   public void dagComplete(DAG dag) {
     // TODO TEZ-2335. Cleanup TaskHeartbeat handler structures.
@@ -308,50 +290,82 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements
   @Override
   public void registerRunningContainer(ContainerId containerId) {
     if (LOG.isDebugEnabled()) {
-      LOG.debug("ContainerId: " + containerId
-          + " registered with TaskAttemptListener");
+      LOG.debug("ContainerId: " + containerId + " registered with TaskAttemptListener");
     }
-    ContainerInfo oldInfo = registeredContainers.put(containerId, new ContainerInfo());
-    if(oldInfo != null) {
+    ContainerInfo oldInfo = registeredContainers.put(containerId, NULL_CONTAINER_INFO);
+    if (oldInfo != null) {
       throw new TezUncheckedException(
           "Multiple registrations for containerId: " + containerId);
     }
+    NodeId nodeId = context.getAllContainers().get(containerId).getContainer().getNodeId();
+    taskCommunicator.registerRunningContainer(containerId, nodeId.getHost(), nodeId.getPort());
+  }
+
+  @Override
+  public void unregisterRunningContainer(ContainerId containerId) {
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Unregistering Container from TaskAttemptListener: " + containerId);
+    }
+    ContainerInfo containerInfo = registeredContainers.remove(containerId);
+    if (containerInfo != null && containerInfo.taskAttemptId != null) { // remove() yields null for an unregistered container
+      registeredAttempts.remove(containerInfo.taskAttemptId); // drop the attempt still bound to this container
+    }
+    taskCommunicator.registerContainerEnd(containerId); // the communicator is notified in every case
   }
 
   @Override
   public void registerTaskAttempt(AMContainerTask amContainerTask,
-      ContainerId containerId) {
+                                  ContainerId containerId) {
     ContainerInfo containerInfo = registeredContainers.get(containerId);
-    if(containerInfo == null) {
+    if (containerInfo == null) {
       throw new TezUncheckedException("Registering task attempt: "
           + amContainerTask.getTask().getTaskAttemptID() + " to unknown container: " + containerId);
     }
-    synchronized (containerInfo) {
-      if(containerInfo.amContainerTask != null) {
-        throw new TezUncheckedException("Registering task attempt: "
-            + amContainerTask.getTask().getTaskAttemptID() + " to container: " + containerId
-            + " with existing assignment to: " + containerInfo.amContainerTask.getTask().getTaskAttemptID());
-      }
-      containerInfo.amContainerTask = amContainerTask;
-      containerInfo.taskPulled = false;
-
-      ContainerId containerIdFromMap =
-          attemptToInfoMap.put(amContainerTask.getTask().getTaskAttemptID(), containerId);
-      if(containerIdFromMap != null) {
-        throw new TezUncheckedException("Registering task attempt: "
-            + amContainerTask.getTask().getTaskAttemptID() + " to container: " + containerId
-            + " when already assigned to: " + containerIdFromMap);
-      }
+    if (containerInfo.taskAttemptId != null) {
+      throw new TezUncheckedException("Registering task attempt: "
+          + amContainerTask.getTask().getTaskAttemptID() + " to container: " + containerId
+          + " with existing assignment to: " +
+          containerInfo.taskAttemptId);
     }
+
+    if (containerInfo.taskAttemptId != null) {
+      throw new TezUncheckedException("Registering task attempt: "
+          + amContainerTask.getTask().getTaskAttemptID() + " to container: " + containerId
+          + " with existing assignment to: " +
+          containerInfo.taskAttemptId);
+    }
+
+    // Explicitly putting in a new entry so that synchronization is not required on the existing element in the map.
+    registeredContainers.put(containerId, new ContainerInfo(amContainerTask.getTask().getTaskAttemptID()));
+
+    ContainerId containerIdFromMap = registeredAttempts.put(
+        amContainerTask.getTask().getTaskAttemptID(), containerId);
+    if (containerIdFromMap != null) {
+      throw new TezUncheckedException("Registering task attempt: "
+          + amContainerTask.getTask().getTaskAttemptID() + " to container: " + containerId
+          + " when already assigned to: " + containerIdFromMap);
+    }
+    taskCommunicator.registerRunningTaskAttempt(containerId, amContainerTask.getTask(),
+        amContainerTask.getAdditionalResources(), amContainerTask.getCredentials(),
+        amContainerTask.haveCredentialsChanged());
   }
 
   @Override
-  public void unregisterRunningContainer(ContainerId containerId) {
-    if (LOG.isDebugEnabled()) {
-      LOG.debug("Unregistering Container from TaskAttemptListener: "
-          + containerId);
+  public void unregisterTaskAttempt(TezTaskAttemptID attemptId) {
+    ContainerId containerId = registeredAttempts.remove(attemptId);
+    if (containerId == null) {
+      LOG.warn("Unregister task attempt: " + attemptId + " from unknown container");
+      return;
+    }
+    ContainerInfo containerInfo = registeredContainers.get(containerId);
+    if (containerInfo == null) {
+      LOG.warn("Unregister task attempt: " + attemptId +
+          " from non-registered container: " + containerId);
+      return;
     }
-    registeredContainers.remove(containerId);
+    // Explicitly putting in a new entry so that synchronization is not required on the existing element in the map.
+    registeredContainers.put(containerId, NULL_CONTAINER_INFO);
+    taskCommunicator.unregisterRunningTaskAttempt(attemptId);
   }
 
   private void pingContainerHeartbeatHandler(ContainerId containerId) {
@@ -359,7 +373,7 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements
   }
 
   private void pingContainerHeartbeatHandler(TezTaskAttemptID taskAttemptId) {
-    ContainerId containerId = attemptToInfoMap.get(taskAttemptId);
+    ContainerId containerId = registeredAttempts.get(taskAttemptId);
     if (containerId != null) {
       containerHeartbeatHandler.pinged(containerId);
     } else {
@@ -368,91 +382,6 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements
     }
   }
 
-  @Override
-  public TezHeartbeatResponse heartbeat(TezHeartbeatRequest request)
-      throws IOException, TezException {
-    ContainerId containerId = ConverterUtils.toContainerId(request
-        .getContainerIdentifier());
-    long requestId = request.getRequestId();
-    if (LOG.isDebugEnabled()) {
-      LOG.debug("Received heartbeat from container"
-          + ", request=" + request);
-    }
-
-    ContainerInfo containerInfo = registeredContainers.get(containerId);
-    if(containerInfo == null) {
-      LOG.warn("Received task heartbeat from unknown container with id: " + containerId +
-          ", asking it to die");
-      TezHeartbeatResponse response = new TezHeartbeatResponse();
-      response.setLastRequestId(requestId);
-      response.setShouldDie();
-      return response;
-    }
-
-    synchronized (containerInfo) {
-      pingContainerHeartbeatHandler(containerId);
-
-      if(containerInfo.lastRequestId == requestId) {
-        LOG.warn("Old sequenceId received: " + requestId
-            + ", Re-sending last response to client");
-        return containerInfo.lastReponse;
-      }
-
-      TezHeartbeatResponse response = new TezHeartbeatResponse();
-      response.setLastRequestId(requestId);
-
-      TezTaskAttemptID taskAttemptID = request.getCurrentTaskAttemptID();
-      if (taskAttemptID != null) {
-        ContainerId containerIdFromMap = attemptToInfoMap.get(taskAttemptID);
-        if(containerIdFromMap == null || !containerIdFromMap.equals(containerId)) {
-          throw new TezException("Attempt " + taskAttemptID
-            + " is not recognized for heartbeat");
-        }
-
-        if(containerInfo.lastRequestId+1 != requestId) {
-          throw new TezException("Container " + containerId
-              + " has invalid request id. Expected: "
-              + containerInfo.lastRequestId+1
-              + " and actual: " + requestId);
-        }
-
-        List<TezEvent> inEvents = request.getEvents();
-        if (LOG.isDebugEnabled()) {
-          LOG.debug("Ping from " + taskAttemptID.toString() +
-              " events: " + (inEvents != null? inEvents.size() : -1));
-        }
-
-        List<TezEvent> otherEvents = new ArrayList<TezEvent>();
-        for (TezEvent tezEvent : ListUtils.emptyIfNull(inEvents)) {
-          final EventType eventType = tezEvent.getEventType();
-          if (eventType == EventType.TASK_STATUS_UPDATE_EVENT ||
-              eventType == EventType.TASK_ATTEMPT_COMPLETED_EVENT) {
-            context.getEventHandler()
-                .handle(getTaskAttemptEventFromTezEvent(taskAttemptID, tezEvent));
-          } else {
-            otherEvents.add(tezEvent);
-          }
-        }
-        if(!otherEvents.isEmpty()) {
-          TezVertexID vertexId = taskAttemptID.getTaskID().getVertexID();
-          context.getEventHandler().handle(
-              new VertexEventRouteEvent(vertexId, Collections.unmodifiableList(otherEvents)));
-        }
-        taskHeartbeatHandler.pinged(taskAttemptID);
-        List<TezEvent> outEvents = context
-            .getCurrentDAG()
-            .getVertex(taskAttemptID.getTaskID().getVertexID())
-            .getTask(taskAttemptID.getTaskID())
-            .getTaskAttemptTezEvents(taskAttemptID, request.getStartIndex(),
-                request.getMaxEvents());
-        response.setEvents(outEvents);
-      }
-      containerInfo.lastRequestId = requestId;
-      containerInfo.lastReponse = response;
-      return response;
-    }
-  }
-
   private TaskAttemptEvent getTaskAttemptEventFromTezEvent(TezTaskAttemptID taskAttemptID,
                                                            TezEvent tezEvent) {
     final EventType eventType = tezEvent.getEventType();
@@ -475,51 +404,7 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements
     return taskAttemptEvent;
   }
 
-  private Map<String, TezLocalResource> convertLocalResourceMap(Map<String, LocalResource> ylrs)
-      throws IOException {
-    Map<String, TezLocalResource> tlrs = Maps.newHashMap();
-    if (ylrs != null) {
-      for (Entry<String, LocalResource> ylrEntry : ylrs.entrySet()) {
-        TezLocalResource tlr;
-        try {
-          tlr = TezConverterUtils.convertYarnLocalResourceToTez(ylrEntry.getValue());
-        } catch (URISyntaxException e) {
-         throw new IOException(e);
-        }
-        tlrs.put(ylrEntry.getKey(), tlr);
-      }
-    }
-    return tlrs;
-  }
-
-  private ContainerTask getContainerTask(ContainerId containerId) throws IOException {
-    ContainerTask containerTask = null;
-    ContainerInfo containerInfo = registeredContainers.get(containerId);
-    if (containerInfo == null) {
-      // This can happen if an unregisterTask comes in after we've done the initial checks for
-      // registered containers. (Race between getTask from the container, and a potential STOP_CONTAINER
-      // from somewhere within the AM)
-      // Implies that an un-registration has taken place and the container needs to be asked to die.
-      LOG.info("Container with id: " + containerId
-          + " is valid, but no longer registered, and will be killed");
-      containerTask = TASK_FOR_INVALID_JVM;
-    } else {
-      synchronized (containerInfo) {
-        if (containerInfo.amContainerTask != null) {
-          if (!containerInfo.taskPulled) {
-            containerInfo.taskPulled = true;
-            AMContainerTask amContainerTask = containerInfo.amContainerTask;
-            containerTask = new ContainerTask(amContainerTask.getTask(), false,
-                convertLocalResourceMap(amContainerTask.getAdditionalResources()),
-                amContainerTask.getCredentials(), amContainerTask.haveCredentialsChanged());
-          } else {
-            containerTask = null;
-          }
-        } else {
-          containerTask = null;
-        }
-      }
-    }
-    return containerTask;
+  public TaskCommunicator getTaskCommunicator() {
+    return taskCommunicator;
   }
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/56986504/tez-dag/src/main/java/org/apache/tez/dag/app/TezTaskCommunicatorImpl.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/TezTaskCommunicatorImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/TezTaskCommunicatorImpl.java
new file mode 100644
index 0000000..5652937
--- /dev/null
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/TezTaskCommunicatorImpl.java
@@ -0,0 +1,474 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app;
+
+import java.io.IOException;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.URISyntaxException;
+import java.net.UnknownHostException;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
+import org.apache.hadoop.ipc.ProtocolSignature;
+import org.apache.hadoop.ipc.RPC;
+import org.apache.hadoop.ipc.Server;
+import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.authorize.PolicyProvider;
+import org.apache.hadoop.security.token.Token;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.LocalResource;
+import org.apache.hadoop.yarn.util.ConverterUtils;
+import org.apache.tez.common.*;
+import org.apache.tez.common.ContainerContext;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.JobTokenSecretManager;
+import org.apache.tez.common.security.TokenCache;
+import org.apache.tez.dag.api.TaskCommunicator;
+import org.apache.tez.dag.api.TaskCommunicatorContext;
+import org.apache.tez.dag.api.TaskHeartbeatRequest;
+import org.apache.tez.dag.api.TaskHeartbeatResponse;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.TezException;
+import org.apache.tez.dag.api.TezUncheckedException;
+import org.apache.tez.dag.app.security.authorize.TezAMPolicyProvider;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.runtime.api.impl.TaskSpec;
+import org.apache.tez.runtime.api.impl.TezHeartbeatRequest;
+import org.apache.tez.runtime.api.impl.TezHeartbeatResponse;
+
+@InterfaceAudience.Private
+public class TezTaskCommunicatorImpl extends TaskCommunicator {
+
+  private static final Log LOG = LogFactory.getLog(TezTaskCommunicatorImpl.class);
+
+  private static final ContainerTask TASK_FOR_INVALID_JVM = new ContainerTask(
+      null, true, null, null, false);
+
+  private final TaskCommunicatorContext taskCommunicatorContext;
+
+  private final ConcurrentMap<ContainerId, ContainerInfo> registeredContainers =
+      new ConcurrentHashMap<ContainerId, ContainerInfo>();
+  private final ConcurrentMap<TaskAttempt, ContainerId> attemptToContainerMap =
+      new ConcurrentHashMap<TaskAttempt, ContainerId>();
+
+  private final TezTaskUmbilicalProtocol taskUmbilical;
+  private InetSocketAddress address;
+  private Server server;
+
+  private static final class ContainerInfo {
+
+    ContainerInfo(ContainerId containerId) {
+      this.containerId = containerId;
+    }
+
+    ContainerId containerId;
+    TezHeartbeatResponse lastResponse = null;
+    TaskSpec taskSpec = null;
+    long lastRequestId = 0;
+    Map<String, LocalResource> additionalLRs = null;
+    Credentials credentials = null;
+    boolean credentialsChanged = false;
+    boolean taskPulled = false;
+
+    void reset() {
+      taskSpec = null;
+      additionalLRs = null;
+      credentials = null;
+      credentialsChanged = false;
+      taskPulled = false;
+    }
+  }
+
+
+
+  /**
+   * Construct the service.
+   */
+  public TezTaskCommunicatorImpl(TaskCommunicatorContext taskCommunicatorContext) {
+    super(TezTaskCommunicatorImpl.class.getName());
+    this.taskCommunicatorContext = taskCommunicatorContext;
+    this.taskUmbilical = new TezTaskUmbilicalProtocolImpl();
+  }
+
+
+  @Override
+  public void serviceStart() {
+
+    startRpcServer();
+  }
+
+  @Override
+  public void serviceStop() {
+    stopRpcServer();
+  }
+
+  protected void startRpcServer() {
+    Configuration conf = getConfig();
+    if (!conf.getBoolean(TezConfiguration.TEZ_LOCAL_MODE, TezConfiguration.TEZ_LOCAL_MODE_DEFAULT)) {
+      try {
+        JobTokenSecretManager jobTokenSecretManager =
+            new JobTokenSecretManager();
+        Token<JobTokenIdentifier> sessionToken = TokenCache.getSessionToken(taskCommunicatorContext.getCredentials());
+        jobTokenSecretManager.addTokenForJob(
+            taskCommunicatorContext.getApplicationAttemptId().getApplicationId().toString(), sessionToken);
+
+        server = new RPC.Builder(conf)
+            .setProtocol(TezTaskUmbilicalProtocol.class)
+            .setBindAddress("0.0.0.0")
+            .setPort(0)
+            .setInstance(taskUmbilical)
+            .setNumHandlers(
+                conf.getInt(TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT,
+                    TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT_DEFAULT))
+            .setSecretManager(jobTokenSecretManager).build();
+
+        // Enable service authorization?
+        if (conf.getBoolean(
+            CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHORIZATION,
+            false)) {
+          refreshServiceAcls(conf, new TezAMPolicyProvider());
+        }
+
+        server.start();
+        this.address = NetUtils.getConnectAddress(server);
+      } catch (IOException e) {
+        throw new TezUncheckedException(e);
+      }
+    } else {
+      try {
+        this.address = new InetSocketAddress(InetAddress.getLocalHost(), 0);
+      } catch (UnknownHostException e) {
+        throw new TezUncheckedException(e);
+      }
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Not starting TaskAttemptListener RPC in LocalMode");
+      }
+    }
+  }
+
+  protected void stopRpcServer() {
+    if (server != null) {
+      server.stop();
+      server = null;
+    }
+  }
+
+  private void refreshServiceAcls(Configuration configuration,
+                                  PolicyProvider policyProvider) {
+    this.server.refreshServiceAcl(configuration, policyProvider);
+  }
+
+  @Override
+  public void registerRunningContainer(ContainerId containerId, String host, int port) {
+    ContainerInfo oldInfo = registeredContainers.putIfAbsent(containerId, new ContainerInfo(containerId));
+    if (oldInfo != null) {
+      throw new TezUncheckedException("Multiple registrations for containerId: " + containerId);
+    }
+  }
+
+  @Override
+  public void registerContainerEnd(ContainerId containerId) {
+    ContainerInfo containerInfo = registeredContainers.remove(containerId); // null when the container was never registered
+    if (containerInfo != null) {
+      synchronized (containerInfo) {
+        if (containerInfo.taskSpec != null && containerInfo.taskSpec.getTaskAttemptID() != null) {
+          attemptToContainerMap.remove(new TaskAttempt(containerInfo.taskSpec.getTaskAttemptID())); // keys are TaskAttempt wrappers; a raw id never matches
+        }
+      }
+    }
+  }
+
+  @Override
+  public void registerRunningTaskAttempt(ContainerId containerId, TaskSpec taskSpec,
+                                         Map<String, LocalResource> additionalResources,
+                                         Credentials credentials, boolean credentialsChanged) {
+
+    ContainerInfo containerInfo = registeredContainers.get(containerId);
+    Preconditions.checkNotNull(containerInfo,
+        "Cannot register task attempt: " + taskSpec.getTaskAttemptID() + " to unknown container: " +
+            containerId);
+    synchronized (containerInfo) {
+      if (containerInfo.taskSpec != null) {
+        throw new TezUncheckedException(
+            "Cannot register task: " + taskSpec.getTaskAttemptID() + " to container: " +
+                containerId + " , with pre-existing assignment: " +
+                containerInfo.taskSpec.getTaskAttemptID());
+      }
+      containerInfo.taskSpec = taskSpec;
+      containerInfo.additionalLRs = additionalResources;
+      containerInfo.credentials = credentials;
+      containerInfo.credentialsChanged = credentialsChanged;
+      containerInfo.taskPulled = false;
+
+      ContainerId oldId = attemptToContainerMap.putIfAbsent(new TaskAttempt(taskSpec.getTaskAttemptID()), containerId);
+      if (oldId != null) {
+        throw new TezUncheckedException(
+            "Attempting to register an already registered taskAttempt with id: " +
+                taskSpec.getTaskAttemptID() + " to containerId: " + containerId +
+                ". Already registered to containerId: " + oldId);
+      }
+    }
+
+  }
+
+  @Override
+  public void unregisterRunningTaskAttempt(TezTaskAttemptID taskAttemptID) {
+    TaskAttempt taskAttempt = new TaskAttempt(taskAttemptID); // wrapper used as the key of attemptToContainerMap
+    ContainerId containerId = attemptToContainerMap.remove(taskAttempt); // removes the attempt-to-container mapping up front
+    if(containerId == null) {
+      LOG.warn("Unregister task attempt: " + taskAttempt + " from unknown container");
+      return;
+    }
+    ContainerInfo containerInfo = registeredContainers.get(containerId);
+    if (containerInfo == null) {
+      LOG.warn("Unregister task attempt: " + taskAttempt +
+          " from non-registered container: " + containerId);
+      return;
+    }
+    synchronized (containerInfo) {
+      containerInfo.reset(); // clears the task spec, additional resources, credentials and the taskPulled flag
+      attemptToContainerMap.remove(taskAttempt); // NOTE(review): looks redundant with the remove above - confirm
+    }
+  }
+
+  @Override
+  public InetSocketAddress getAddress() {
+    return address;
+  }
+
+  public TezTaskUmbilicalProtocol getUmbilical() {
+    return this.taskUmbilical;
+  }
+
+  private class TezTaskUmbilicalProtocolImpl implements TezTaskUmbilicalProtocol {
+
+    @Override
+    public ContainerTask getTask(ContainerContext containerContext) throws IOException {
+      ContainerTask task = null;
+      if (containerContext == null || containerContext.getContainerIdentifier() == null) {
+        LOG.info("Invalid task request with an empty containerContext or containerId");
+        task = TASK_FOR_INVALID_JVM;
+      } else {
+        ContainerId containerId = ConverterUtils.toContainerId(containerContext
+            .getContainerIdentifier());
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Container with id: " + containerId + " asked for a task");
+        }
+        task = getContainerTask(containerId);
+        if (task != null && !task.shouldDie()) {
+          taskCommunicatorContext
+              .taskStartedRemotely(task.getTaskSpec().getTaskAttemptID(), containerId);
+        }
+      }
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("getTask returning task: " + task);
+      }
+      return task;
+    }
+
+    @Override
+    public boolean canCommit(TezTaskAttemptID taskAttemptId) throws IOException {
+      return taskCommunicatorContext.canCommit(taskAttemptId);
+    }
+
+    @Override
+    public TezHeartbeatResponse heartbeat(TezHeartbeatRequest request) throws IOException,
+        TezException { // validates container and request sequence, then delegates to taskCommunicatorContext
+      ContainerId containerId = ConverterUtils.toContainerId(request.getContainerIdentifier());
+      long requestId = request.getRequestId();
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Received heartbeat from container"
+            + ", request=" + request);
+      }
+      // Containers that are not registered get a response with the should-die flag set.
+      ContainerInfo containerInfo = registeredContainers.get(containerId);
+      if (containerInfo == null) {
+        LOG.warn("Received task heartbeat from unknown container with id: " + containerId +
+            ", asking it to die");
+        TezHeartbeatResponse response = new TezHeartbeatResponse();
+        response.setLastRequestId(requestId);
+        response.setShouldDie();
+        return response;
+      }
+      // A request id equal to the last one seen is answered with the cached response.
+      synchronized (containerInfo) {
+        if (containerInfo.lastRequestId == requestId) {
+          LOG.warn("Old sequenceId received: " + requestId
+              + ", Re-sending last response to client");
+          return containerInfo.lastResponse;
+        }
+      }
+
+      TaskHeartbeatResponse tResponse = null;
+      TezTaskAttemptID taskAttemptID = request.getCurrentTaskAttemptID();
+      if (taskAttemptID != null) {
+        synchronized (containerInfo) {
+          ContainerId containerIdFromMap = attemptToContainerMap.get(new TaskAttempt(taskAttemptID));
+          if (containerIdFromMap == null || !containerIdFromMap.equals(containerId)) {
+            throw new TezException("Attempt " + taskAttemptID
+                + " is not recognized for heartbeat");
+          }
+          // With a task attached, request ids must advance by exactly one per heartbeat.
+          if (containerInfo.lastRequestId + 1 != requestId) {
+            throw new TezException("Container " + containerId
+                + " has invalid request id. Expected: "
+                + (containerInfo.lastRequestId + 1) // parenthesized: otherwise "1" is appended as text
+                + " and actual: " + requestId);
+          }
+        }
+        TaskHeartbeatRequest tRequest = new TaskHeartbeatRequest(request.getContainerIdentifier(),
+            request.getCurrentTaskAttemptID(), request.getEvents(), request.getStartIndex(),
+            request.getMaxEvents());
+        tResponse = taskCommunicatorContext.heartbeat(tRequest); // called without holding the container monitor
+      }
+      TezHeartbeatResponse response;
+      if (tResponse == null) {
+        response = new TezHeartbeatResponse();
+      } else {
+        response = new TezHeartbeatResponse(tResponse.getEvents());
+      }
+      response.setLastRequestId(requestId);
+      synchronized (containerInfo) { // same monitor as the reads above, so id and cached response are published together
+        containerInfo.lastRequestId = requestId;
+        containerInfo.lastResponse = response;
+      }
+      return response;
+    }
+
+
+    // TODO Remove this method once we move to the Protobuf RPC engine
+    @Override
+    public long getProtocolVersion(String protocol, long clientVersion) throws IOException {
+      return versionID;
+    }
+
+    // TODO Remove this method once we move to the Protobuf RPC engine
+    @Override
+    public ProtocolSignature getProtocolSignature(String protocol, long clientVersion,
+                                                  int clientMethodsHash) throws IOException {
+      return ProtocolSignature.getProtocolSignature(this, protocol,
+          clientVersion, clientMethodsHash);
+    }
+  }
+
+  private ContainerTask getContainerTask(ContainerId containerId) throws IOException { // picks what to return for a container's getTask call
+    ContainerInfo containerInfo = registeredContainers.get(containerId);
+    ContainerTask task = null;
+    if (containerInfo == null) {
+      if (taskCommunicatorContext.isKnownContainer(containerId)) { // only the log message differs between the two cases
+        LOG.info("Container with id: " + containerId
+            + " is valid, but no longer registered, and will be killed");
+      } else {
+        LOG.info("Container with id: " + containerId
+            + " is invalid and will be killed");
+      }
+      task = TASK_FOR_INVALID_JVM; // shared marker task returned for any unregistered container
+    } else {
+      synchronized (containerInfo) {
+        if (containerInfo.taskSpec != null) {
+          if (!containerInfo.taskPulled) {
+            containerInfo.taskPulled = true; // a registered task is handed out at most once
+            task = constructContainerTask(containerInfo);
+          } else {
+            if (LOG.isDebugEnabled()) {
+              LOG.debug("Task " + containerInfo.taskSpec.getTaskAttemptID() +
+                  " already sent to container: " + containerId);
+            }
+            task = null;
+          }
+        } else {
+          task = null; // registered container with no task assigned: nothing to return
+          if (LOG.isDebugEnabled()) {
+            LOG.debug("No task assigned yet for running container: " + containerId);
+          }
+        }
+      }
+    }
+    return task;
+  }
+
+  private ContainerTask constructContainerTask(ContainerInfo containerInfo) throws IOException {
+    return new ContainerTask(containerInfo.taskSpec, false,
+        convertLocalResourceMap(containerInfo.additionalLRs), containerInfo.credentials,
+        containerInfo.credentialsChanged);
+  }
+
+  private Map<String, TezLocalResource> convertLocalResourceMap(Map<String, LocalResource> ylrs)
+      throws IOException {
+    Map<String, TezLocalResource> tlrs = Maps.newHashMap();
+    if (ylrs != null) {
+      for (Map.Entry<String, LocalResource> ylrEntry : ylrs.entrySet()) {
+        TezLocalResource tlr;
+        try {
+          tlr = TezConverterUtils.convertYarnLocalResourceToTez(ylrEntry.getValue());
+        } catch (URISyntaxException e) {
+          throw new IOException(e);
+        }
+        tlrs.put(ylrEntry.getKey(), tlr);
+      }
+    }
+    return tlrs;
+  }
+
+
+  // Key type of attemptToContainerMap wrapping a TezTaskAttemptID; per the TODOs it may become vertex name, id, version.
+  private static class TaskAttempt {
+    // TODO TEZ-2003 Change this to work with VertexName, int id, int version
+    // TODO TEZ-2003 Avoid constructing this unit all over the place
+    private final TezTaskAttemptID taskAttemptId; // final: instances are hash-map keys, so the id must not change
+
+    TaskAttempt(TezTaskAttemptID taskAttemptId) {
+      this.taskAttemptId = taskAttemptId;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (!(o instanceof TaskAttempt)) {
+        return false;
+      }
+      // Two wrappers are equal exactly when the wrapped attempt ids are equal.
+      TaskAttempt that = (TaskAttempt) o;
+
+      if (!taskAttemptId.equals(that.taskAttemptId)) {
+        return false;
+      }
+
+      return true;
+    }
+
+    @Override
+    public int hashCode() {
+      return taskAttemptId.hashCode(); // delegates to the wrapped id, consistent with equals
+    }
+
+    @Override
+    public String toString() {
+      return "TaskAttempt{" + "taskAttemptId=" + taskAttemptId + '}';
+    }
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/tez/blob/56986504/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java
index 9faf8c0..e9ba9d7 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java
@@ -59,6 +59,8 @@ import org.apache.tez.dag.api.TezException;
 import org.apache.tez.dag.api.TezUncheckedException;
 import org.apache.tez.dag.app.AppContext;
 import org.apache.tez.dag.app.TaskAttemptListener;
+import org.apache.tez.dag.app.TaskAttemptListenerImpTezDag;
+import org.apache.tez.dag.app.TezTaskCommunicatorImpl;
 import org.apache.tez.dag.app.rm.NMCommunicatorEvent;
 import org.apache.tez.dag.app.rm.NMCommunicatorLaunchRequestEvent;
 import org.apache.tez.dag.app.rm.NMCommunicatorStopRequestEvent;
@@ -86,7 +88,7 @@ public class LocalContainerLauncher extends AbstractService implements
 
   private static final Logger LOG = LoggerFactory.getLogger(LocalContainerLauncher.class);
   private final AppContext context;
-  private final TaskAttemptListener taskAttemptListener;
+  private final TezTaskUmbilicalProtocol taskUmbilicalProtocol;
   private final AtomicBoolean serviceStopped = new AtomicBoolean(false);
   private final String workingDirectory;
   private final Map<String, String> localEnv = new HashMap<String, String>();
@@ -114,7 +116,9 @@ public class LocalContainerLauncher extends AbstractService implements
                                 String workingDirectory) throws UnknownHostException {
     super(LocalContainerLauncher.class.getName());
     this.context = context;
-    this.taskAttemptListener = taskAttemptListener;
+    TaskAttemptListenerImpTezDag taListener = (TaskAttemptListenerImpTezDag)taskAttemptListener;
+    TezTaskCommunicatorImpl taskComm = (TezTaskCommunicatorImpl) taListener.getTaskCommunicator();
+    this.taskUmbilicalProtocol = taskComm.getUmbilical();
     this.workingDirectory = workingDirectory;
     AuxiliaryServiceHelper.setServiceDataIntoEnv(
         ShuffleUtils.SHUFFLE_HANDLER_SERVICE_ID, ByteBuffer.allocate(4).putInt(0), localEnv);
@@ -215,7 +219,7 @@ public class LocalContainerLauncher extends AbstractService implements
         tezChild =
             createTezChild(context.getAMConf(), event.getContainerId(), tokenIdentifier,
                 context.getApplicationAttemptId().getAttemptId(), context.getLocalDirs(),
-                (TezTaskUmbilicalProtocol) taskAttemptListener,
+                taskUmbilicalProtocol,
                 TezCommonUtils.parseCredentialsBytes(event.getContainerLaunchContext().getTokens().array()));
       } catch (InterruptedException e) {
         handleLaunchFailed(e, event.getContainerId());

http://git-wip-us.apache.org/repos/asf/tez/blob/56986504/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainer.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainer.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainer.java
index a6b403d..0fc2e12 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainer.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainer.java
@@ -22,6 +22,7 @@ import java.util.List;
 
 import org.apache.hadoop.yarn.api.records.Container;
 import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.NodeId;
 import org.apache.hadoop.yarn.event.EventHandler;
 import org.apache.tez.dag.records.TezTaskAttemptID;
 
@@ -32,5 +33,5 @@ public interface AMContainer extends EventHandler<AMContainerEvent>{
   public Container getContainer();
   public List<TezTaskAttemptID> getAllTaskAttempts();
   public TezTaskAttemptID getCurrentTaskAttempt();
-  
+
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/56986504/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerEventAssignTA.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerEventAssignTA.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerEventAssignTA.java
index 682cd02..0398882 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerEventAssignTA.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerEventAssignTA.java
@@ -27,6 +27,8 @@ import org.apache.tez.runtime.api.impl.TaskSpec;
 
 public class AMContainerEventAssignTA extends AMContainerEvent {
 
+  // TODO TEZ-2003. Add the task priority to this event.
+
   private final TezTaskAttemptID attemptId;
   // TODO Maybe have tht TAL pull the remoteTask from the TaskAttempt itself ?
   private final TaskSpec remoteTaskSpec;

http://git-wip-us.apache.org/repos/asf/tez/blob/56986504/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerImpl.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerImpl.java
index 330f2b7..1acec9c 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerImpl.java
@@ -35,6 +35,7 @@ import org.apache.hadoop.yarn.api.records.ContainerExitStatus;
 import org.apache.hadoop.yarn.api.records.ContainerId;
 import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
 import org.apache.hadoop.yarn.api.records.LocalResource;
+import org.apache.hadoop.yarn.api.records.NodeId;
 import org.apache.hadoop.yarn.event.Event;
 import org.apache.hadoop.yarn.event.EventHandler;
 import org.apache.hadoop.yarn.state.InvalidStateTransitonException;

http://git-wip-us.apache.org/repos/asf/tez/blob/56986504/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java b/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java
index 18286b5..0b2ea2f 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java
@@ -37,6 +37,7 @@ import java.util.concurrent.atomic.AtomicLong;
 import org.apache.tez.dag.app.dag.DAG;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import com.google.common.annotations.VisibleForTesting;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.security.Credentials;
 import org.apache.hadoop.service.AbstractService;
@@ -50,7 +51,10 @@ import org.apache.tez.client.TezApiVersionInfo;
 import org.apache.tez.common.ContainerContext;
 import org.apache.tez.common.ContainerTask;
 import org.apache.tez.common.counters.TezCounters;
+import org.apache.tez.dag.api.TaskHeartbeatRequest;
+import org.apache.tez.dag.api.TaskHeartbeatResponse;
 import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.TaskCommunicator;
 import org.apache.tez.dag.api.TezUncheckedException;
 import org.apache.tez.dag.app.launcher.ContainerLauncher;
 import org.apache.tez.dag.app.rm.NMCommunicatorEvent;
@@ -72,8 +76,6 @@ import org.apache.tez.runtime.api.impl.TaskSpec;
 import org.apache.tez.runtime.api.impl.TaskStatistics;
 import org.apache.tez.runtime.api.impl.TezEvent;
 import org.apache.tez.runtime.api.impl.EventMetaData.EventProducerConsumerType;
-import org.apache.tez.runtime.api.impl.TezHeartbeatRequest;
-import org.apache.tez.runtime.api.impl.TezHeartbeatResponse;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
@@ -126,6 +128,7 @@ public class MockDAGAppMaster extends DAGAppMaster {
     Map<ContainerId, ContainerData> containers = Maps.newConcurrentMap();
     ArrayBlockingQueue<Worker> workers;
     TaskAttemptListenerImpTezDag taListener;
+    TezTaskCommunicatorImpl taskCommunicator;
     
     AtomicBoolean startScheduling = new AtomicBoolean(true);
     AtomicBoolean goFlag;
@@ -188,6 +191,7 @@ public class MockDAGAppMaster extends DAGAppMaster {
     @Override
     public void serviceStart() throws Exception {
       taListener = (TaskAttemptListenerImpTezDag) getTaskAttemptListener();
+      taskCommunicator = (TezTaskCommunicatorImpl) taListener.getTaskCommunicator();
       eventHandlingThread = new Thread(this);
       eventHandlingThread.start();
       ExecutorService rawExecutor = Executors.newFixedThreadPool(handlerConcurrency,
@@ -327,10 +331,10 @@ public class MockDAGAppMaster extends DAGAppMaster {
       }
     }
     
-    private void doHeartbeat(TezHeartbeatRequest request, ContainerData cData) throws Exception {
+    private void doHeartbeat(TaskHeartbeatRequest request, ContainerData cData) throws Exception {
       long startTime = System.nanoTime();
       long startCpuTime = threadMxBean.getCurrentThreadCpuTime();
-      TezHeartbeatResponse response = taListener.heartbeat(request);
+      TaskHeartbeatResponse response = taListener.heartbeat(request);
       if (response.shouldDie()) {
         cData.remove();
       } else {
@@ -381,7 +385,8 @@ public class MockDAGAppMaster extends DAGAppMaster {
         try {
           if (cData.taId == null) {
             // if container is not assigned a task, ask for a task
-            ContainerTask cTask = taListener.getTask(new ContainerContext(cData.cIdStr));
+            ContainerTask cTask =
+                taskCommunicator.getUmbilical().getTask(new ContainerContext(cData.cIdStr));
             if (cTask != null) {
               if (cTask.shouldDie()) {
                 cData.remove();
@@ -427,8 +432,9 @@ public class MockDAGAppMaster extends DAGAppMaster {
               float progress = updateProgress ? cData.numUpdates/maxUpdates : 0f;
               events.add(new TezEvent(new TaskStatusUpdateEvent(counters, progress, stats), new EventMetaData(
                   EventProducerConsumerType.SYSTEM, cData.vName, "", cData.taId)));
-              TezHeartbeatRequest request = new TezHeartbeatRequest(cData.numUpdates, events,
-                  cData.cIdStr, cData.taId, cData.nextFromEventId, 10000);
+              TaskHeartbeatRequest request =
+                  new TaskHeartbeatRequest(cData.cIdStr, cData.taId, events, cData.nextFromEventId,
+                      10000);
               doHeartbeat(request, cData);
             } else if (version != null && cData.taId.getId() <= version.intValue()) {
               preemptContainer(cData);
@@ -438,8 +444,9 @@ public class MockDAGAppMaster extends DAGAppMaster {
               List<TezEvent> events = Collections.singletonList(new TezEvent(
                   new TaskAttemptCompletedEvent(), new EventMetaData(
                       EventProducerConsumerType.SYSTEM, cData.vName, "", cData.taId)));
-              TezHeartbeatRequest request = new TezHeartbeatRequest(++cData.numUpdates, events,
-                  cData.cIdStr, cData.taId, cData.nextFromEventId, 10000);
+              TaskHeartbeatRequest request =
+                  new TaskHeartbeatRequest(cData.cIdStr, cData.taId, events, cData.nextFromEventId,
+                      10000);
               doHeartbeat(request, cData);
               cData.clear();
             }

http://git-wip-us.apache.org/repos/asf/tez/blob/56986504/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java
index ec4f99a..286e897 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java
@@ -1,16 +1,16 @@
 /*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
 
 package org.apache.tez.dag.app;
 
@@ -18,6 +18,7 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
@@ -33,6 +34,7 @@ import java.util.Map;
 import org.apache.hadoop.yarn.api.records.ApplicationAccessType;
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.api.records.Container;
 import org.apache.hadoop.yarn.api.records.ContainerId;
 import org.apache.hadoop.yarn.event.Event;
 import org.apache.hadoop.yarn.event.EventHandler;
@@ -40,6 +42,12 @@ import org.apache.tez.common.ContainerContext;
 import org.apache.tez.common.ContainerTask;
 import org.apache.tez.common.security.JobTokenSecretManager;
 import org.apache.tez.dag.api.TezException;
+import org.apache.hadoop.yarn.api.records.NodeId;
+import org.apache.hadoop.yarn.event.EventHandler;
+import org.apache.tez.common.ContainerContext;
+import org.apache.tez.common.ContainerTask;
+import org.apache.tez.common.TezTaskUmbilicalProtocol;
+import org.apache.tez.dag.api.TaskCommunicatorContext;
 import org.apache.tez.dag.app.dag.DAG;
 import org.apache.tez.dag.app.dag.Task;
 import org.apache.tez.dag.app.dag.Vertex;
@@ -99,9 +107,18 @@ public class TestTaskAttemptListenerImplTezDag {
     doReturn(dag).when(appContext).getCurrentDAG();
     doReturn(appAcls).when(appContext).getApplicationACLs();
     doReturn(amContainerMap).when(appContext).getAllContainers();
-
-    taskAttemptListener = new TaskAttemptListenerImplForTest(appContext,
-        mock(TaskHeartbeatHandler.class), mock(ContainerHeartbeatHandler.class), null);
+    NodeId nodeId = NodeId.newInstance("localhost", 0);
+    AMContainer amContainer = mock(AMContainer.class);
+    Container container = mock(Container.class);
+    doReturn(nodeId).when(container).getNodeId();
+    doReturn(amContainer).when(amContainerMap).get(any(ContainerId.class));
+    doReturn(container).when(amContainer).getContainer();
+
+    taskAttemptListener =
+        new TaskAttemptListenerImpTezDag(appContext, mock(TaskHeartbeatHandler.class),
+            mock(ContainerHeartbeatHandler.class), null);
+    TezTaskCommunicatorImpl taskCommunicator = (TezTaskCommunicatorImpl)taskAttemptListener.getTaskCommunicator();
+    TezTaskUmbilicalProtocol tezUmbilical = taskCommunicator.getUmbilical();
 
     taskSpec = mock(TaskSpec.class);
     doReturn(taskAttemptID).when(taskSpec).getTaskAttemptID();
@@ -113,32 +130,30 @@ public class TestTaskAttemptListenerImplTezDag {
   public void testGetTask() throws IOException {
 
     ContainerId containerId1 = createContainerId(appId, 1);
-    doReturn(mock(AMContainer.class)).when(amContainerMap).get(containerId1);
     ContainerContext containerContext1 = new ContainerContext(containerId1.toString());
-    containerTask = taskAttemptListener.getTask(containerContext1);
+    containerTask = tezUmbilical.getTask(containerContext1);
     assertTrue(containerTask.shouldDie());
 
     ContainerId containerId2 = createContainerId(appId, 2);
-    doReturn(mock(AMContainer.class)).when(amContainerMap).get(containerId2);
     ContainerContext containerContext2 = new ContainerContext(containerId2.toString());
     taskAttemptListener.registerRunningContainer(containerId2);
-    containerTask = taskAttemptListener.getTask(containerContext2);
+    containerTask = tezUmbilical.getTask(containerContext2);
     assertNull(containerTask);
 
     // Valid task registered
     taskAttemptListener.registerTaskAttempt(amContainerTask, containerId2);
-    containerTask = taskAttemptListener.getTask(containerContext2);
+    containerTask = tezUmbilical.getTask(containerContext2);
     assertFalse(containerTask.shouldDie());
     assertEquals(taskSpec, containerTask.getTaskSpec());
 
     // Task unregistered. Should respond to heartbeats
-    taskAttemptListener.unregisterTaskAttempt(taskAttemptID);
-    containerTask = taskAttemptListener.getTask(containerContext2);
+    taskAttemptListener.unregisterTaskAttempt(taskAttemptId);
+    containerTask = tezUmbilical.getTask(containerContext2);
     assertNull(containerTask);
 
     // Container unregistered. Should send a shouldDie = true
     taskAttemptListener.unregisterRunningContainer(containerId2);
-    containerTask = taskAttemptListener.getTask(containerContext2);
+    containerTask = tezUmbilical.getTask(containerContext2);
     assertTrue(containerTask.shouldDie());
 
     ContainerId containerId3 = createContainerId(appId, 3);
@@ -152,27 +167,30 @@ public class TestTaskAttemptListenerImplTezDag {
     AMContainerTask amContainerTask2 = new AMContainerTask(taskSpec, null, null, false, 0);
     taskAttemptListener.registerTaskAttempt(amContainerTask2, containerId3);
     taskAttemptListener.unregisterRunningContainer(containerId3);
-    containerTask = taskAttemptListener.getTask(containerContext3);
+    containerTask = tezUmbilical.getTask(containerContext3);
     assertTrue(containerTask.shouldDie());
   }
 
   @Test(timeout = 5000)
   public void testGetTaskMultiplePulls() throws IOException {
+    TezTaskCommunicatorImpl taskCommunicator = (TezTaskCommunicatorImpl)taskAttemptListener.getTaskCommunicator();
+    TezTaskUmbilicalProtocol tezUmbilical = taskCommunicator.getUmbilical();
+
     ContainerId containerId1 = createContainerId(appId, 1);
     doReturn(mock(AMContainer.class)).when(amContainerMap).get(containerId1);
     ContainerContext containerContext1 = new ContainerContext(containerId1.toString());
     taskAttemptListener.registerRunningContainer(containerId1);
-    containerTask = taskAttemptListener.getTask(containerContext1);
+    containerTask = tezUmbilical.getTask(containerContext1);
     assertNull(containerTask);
 
     // Register task
     taskAttemptListener.registerTaskAttempt(amContainerTask, containerId1);
-    containerTask = taskAttemptListener.getTask(containerContext1);
+    containerTask = tezUmbilical.getTask(containerContext1);
     assertFalse(containerTask.shouldDie());
     assertEquals(taskSpec, containerTask.getTaskSpec());
 
     // Try pulling again - simulates re-use pull
-    containerTask = taskAttemptListener.getTask(containerContext1);
+    containerTask = tezUmbilical.getTask(containerContext1);
     assertNull(containerTask);
   }
 
@@ -250,13 +268,11 @@ public class TestTaskAttemptListenerImplTezDag {
     return ContainerId.newInstance(appAttemptId, containerIdx);
   }
 
-  private static class TaskAttemptListenerImplForTest extends TaskAttemptListenerImpTezDag {
+  private static class TezTaskCommunicatorImplForTest extends TezTaskCommunicatorImpl {
 
-    public TaskAttemptListenerImplForTest(AppContext context,
-                                          TaskHeartbeatHandler thh,
-                                          ContainerHeartbeatHandler chh,
-                                          JobTokenSecretManager jobTokenSecretManager) {
-      super(context, thh, chh, jobTokenSecretManager);
+    public TezTaskCommunicatorImplForTest(
+        TaskCommunicatorContext taskCommunicatorContext) {
+      super(taskCommunicatorContext);
     }
 
     @Override