You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@tez.apache.org by ss...@apache.org on 2015/02/10 22:52:28 UTC

[16/17] tez git commit: TEZ-2045. TaskAttemptListener should not pull Tasks from AMContainer. (sseth)

TEZ-2045. TaskAttemptListener should not pull Tasks from AMContainer.
(sseth)


Project: http://git-wip-us.apache.org/repos/asf/tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/fe39ede3
Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/fe39ede3
Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/fe39ede3

Branch: refs/heads/TEZ-2003
Commit: fe39ede3305bab665fcdbca07fd381be0e875e80
Parents: f035468
Author: Siddharth Seth <ss...@apache.org>
Authored: Tue Feb 10 13:42:23 2015 -0800
Committer: Siddharth Seth <ss...@apache.org>
Committed: Tue Feb 10 13:42:23 2015 -0800

----------------------------------------------------------------------
 CHANGES.txt                                     |   1 +
 .../apache/tez/dag/app/TaskAttemptListener.java |  35 +--
 .../dag/app/TaskAttemptListenerImpTezDag.java   |  93 +++---
 .../tez/dag/app/rm/container/AMContainer.java   |   5 +-
 .../app/rm/container/AMContainerEventType.java  |   3 -
 .../dag/app/rm/container/AMContainerImpl.java   | 304 ++++++------------
 .../dag/app/rm/container/AMContainerTask.java   |  10 +-
 .../app/TestTaskAttemptListenerImplTezDag.java  | 182 +++++++++++
 .../dag/app/rm/container/TestAMContainer.java   | 313 ++++++++++---------
 9 files changed, 505 insertions(+), 441 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/tez/blob/fe39ede3/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index d617bee..9979c50 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -7,6 +7,7 @@ Release 0.7.0: Unreleased
 INCOMPATIBLE CHANGES
 
 ALL CHANGES:
+  TEZ-2045. TaskAttemptListener should not pull Tasks from AMContainer. Instead these should be registered with the listener.
   TEZ-1914. VertexManager logic should not run on the central dispatcher
   TEZ-2023. Refactor logIndividualFetchComplete() to be common for both shuffle-schedulers.
   TEZ-1999. IndexOutOfBoundsException during merge.

http://git-wip-us.apache.org/repos/asf/tez/blob/fe39ede3/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListener.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListener.java b/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListener.java
index e80c8b3..aeb0cd5 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListener.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListener.java
@@ -21,6 +21,7 @@ package org.apache.tez.dag.app;
 import java.net.InetSocketAddress;
 
 import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.tez.dag.app.rm.container.AMContainerTask;
 import org.apache.tez.dag.records.TezTaskAttemptID;
 /**
  * This class listens for changes to the state of a Task.
@@ -30,41 +31,11 @@ public interface TaskAttemptListener {
   InetSocketAddress getAddress();
 
   void registerRunningContainer(ContainerId containerId);
-//  void registerRunningJvm(WrappedJvmID jvmID, ContainerId containerId);
-  
-  void registerTaskAttempt(TezTaskAttemptID attemptId, ContainerId containerId);
-  
-//  void registerTaskAttempt(TezTaskAttemptID attemptId, WrappedJvmID jvmId);
+
+  void registerTaskAttempt(AMContainerTask amContainerTask, ContainerId containerId);
   
   void unregisterRunningContainer(ContainerId containerId);
   
-//  void unregisterRunningJvm(WrappedJvmID jvmID);
-  
   void unregisterTaskAttempt(TezTaskAttemptID attemptID);
-  /**
-   * Register a JVM with the listener.  This should be called as soon as a 
-   * JVM ID is assigned to a task attempt, before it has been launched.
-   * @param task the task itself for this JVM.
-   * @param jvmID The ID of the JVM .
-   */
-//  void registerPendingTask(Task task, WrappedJvmID jvmID);
-  
-  /**
-   * Register task attempt. This should be called when the JVM has been
-   * launched.
-   * 
-   * @param attemptID
-   *          the id of the attempt for this JVM.
-   * @param jvmID the ID of the JVM.
-   */
-//  void registerLaunchedTask(TezTaskAttemptID attemptID, WrappedJvmID jvmID);
-
-  /**
-   * Unregister the JVM and the attempt associated with it.  This should be 
-   * called when the attempt/JVM has finished executing and is being cleaned up.
-   * @param attemptID the ID of the attempt.
-   * @param jvmID the ID of the JVM for that attempt.
-   */
-//  void unregister(TezTaskAttemptID attemptID, WrappedJvmID jvmID);
 
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/fe39ede3/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java b/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java
index b1cb3f6..28f2c32 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/TaskAttemptListenerImpTezDag.java
@@ -53,7 +53,6 @@ import org.apache.tez.dag.app.dag.DAG;
 import org.apache.tez.dag.app.dag.Task;
 import org.apache.tez.dag.app.dag.event.TaskAttemptEventStartedRemotely;
 import org.apache.tez.dag.app.dag.event.VertexEventRouteEvent;
-import org.apache.tez.dag.app.rm.container.AMContainerImpl;
 import org.apache.tez.dag.app.rm.container.AMContainerTask;
 import org.apache.tez.dag.app.security.authorize.TezAMPolicyProvider;
 import org.apache.tez.dag.records.TezTaskAttemptID;
@@ -87,11 +86,13 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements
     ContainerInfo() {
       this.lastReponse = null;
       this.lastRequestId = 0;
-      this.currentAttemptId = null;
+      this.amContainerTask = null;
+      this.taskPulled = false;
     }
     long lastRequestId;
     TezHeartbeatResponse lastReponse;
-    TezTaskAttemptID currentAttemptId;
+    AMContainerTask amContainerTask;
+    boolean taskPulled;
   }
 
   private ConcurrentMap<TezTaskAttemptID, ContainerId> attemptToInfoMap =
@@ -212,30 +213,18 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements
         task = TASK_FOR_INVALID_JVM;
       } else {
         pingContainerHeartbeatHandler(containerId);
-        AMContainerTask taskContext = pullTaskAttemptContext(containerId);
-        if (taskContext.shouldDie()) {
-          LOG.info("No more tasks for container with id : " + containerId
-              + ". Asking it to die");
-          task = TASK_FOR_INVALID_JVM; // i.e. ask the child to die.
+        task = getContainerTask(containerId);
+        if (task == null) {
+          if (LOG.isDebugEnabled()) {
+            LOG.debug("No task current assigned to Container with id: " + containerId);
+          }
         } else {
-          if (taskContext.getTask() == null) {
-            if (LOG.isDebugEnabled()) {
-              LOG.debug("No task currently assigned to Container with id: "
-                  + containerId);
-            }
-          } else {
-            registerTaskAttempt(taskContext.getTask().getTaskAttemptID(),
-                containerId);
-            task = new ContainerTask(taskContext.getTask(), false,
-                convertLocalResourceMap(taskContext.getAdditionalResources()),
-                taskContext.getCredentials(), taskContext.haveCredentialsChanged());
             context.getEventHandler().handle(
-                new TaskAttemptEventStartedRemotely(taskContext.getTask()
+                new TaskAttemptEventStartedRemotely(task.getTaskSpec()
                     .getTaskAttemptID(), containerId, context
                     .getApplicationACLs()));
             LOG.info("Container with id: " + containerId + " given task: "
-                + taskContext.getTask().getTaskAttemptID());
-          }
+                + task.getTaskSpec().getTaskAttemptID());
         }
       }
     }
@@ -283,18 +272,12 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements
       return;
     }
     synchronized (containerInfo) {
-      containerInfo.currentAttemptId = null;
+      containerInfo.amContainerTask = null;
       attemptToInfoMap.remove(attemptId);
     }
 
   }
 
-  public AMContainerTask pullTaskAttemptContext(ContainerId containerId) {
-    AMContainerImpl container = (AMContainerImpl) context.getAllContainers()
-        .get(containerId);
-    return container.pullTaskContext();
-  }
-
   @Override
   public void registerRunningContainer(ContainerId containerId) {
     if (LOG.isDebugEnabled()) {
@@ -309,24 +292,27 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements
   }
 
   @Override
-  public void registerTaskAttempt(TezTaskAttemptID attemptId,
+  public void registerTaskAttempt(AMContainerTask amContainerTask,
       ContainerId containerId) {
     ContainerInfo containerInfo = registeredContainers.get(containerId);
     if(containerInfo == null) {
       throw new TezUncheckedException("Registering task attempt: "
-          + attemptId + " to unknown container: " + containerId);
+          + amContainerTask.getTask().getTaskAttemptID() + " to unknown container: " + containerId);
     }
     synchronized (containerInfo) {
-      if(containerInfo.currentAttemptId != null) {
+      if(containerInfo.amContainerTask != null) {
         throw new TezUncheckedException("Registering task attempt: "
-            + attemptId + " to container: " + containerId
-            + " with existing assignment to: " + containerInfo.currentAttemptId);
+            + amContainerTask.getTask().getTaskAttemptID() + " to container: " + containerId
+            + " with existing assignment to: " + containerInfo.amContainerTask.getTask().getTaskAttemptID());
       }
-      containerInfo.currentAttemptId = attemptId;
-      ContainerId containerIdFromMap = attemptToInfoMap.put(attemptId, containerId);
+      containerInfo.amContainerTask = amContainerTask;
+      containerInfo.taskPulled = false;
+
+      ContainerId containerIdFromMap =
+          attemptToInfoMap.put(amContainerTask.getTask().getTaskAttemptID(), containerId);
       if(containerIdFromMap != null) {
         throw new TezUncheckedException("Registering task attempt: "
-            + attemptId + " to container: " + containerId
+            + amContainerTask.getTask().getTaskAttemptID() + " to container: " + containerId
             + " when already assigned to: " + containerIdFromMap);
       }
     }
@@ -368,6 +354,8 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements
 
     ContainerInfo containerInfo = registeredContainers.get(containerId);
     if(containerInfo == null) {
+      LOG.warn("Received task heartbeat from unknown container with id: " + containerId +
+          ", asking it to die");
       TezHeartbeatResponse response = new TezHeartbeatResponse();
       response.setLastRequestId(requestId);
       response.setShouldDie();
@@ -442,4 +430,35 @@ public class TaskAttemptListenerImpTezDag extends AbstractService implements
     }
     return tlrs;
   }
+
+  private ContainerTask getContainerTask(ContainerId containerId) throws IOException {
+    ContainerTask containerTask = null;
+    ContainerInfo containerInfo = registeredContainers.get(containerId);
+    if (containerInfo == null) {
+      // This can happen if an unregisterTask comes in after we've done the initial checks for
+      // registered containers. (Race between getTask from the container, and a potential STOP_CONTAINER
+      // from somewhere within the AM)
+      // Implies that an un-registration has taken place and the container needs to be asked to die.
+      LOG.info("Container with id: " + containerId
+          + " is valid, but no longer registered, and will be killed");
+      containerTask = TASK_FOR_INVALID_JVM;
+    } else {
+      synchronized (containerInfo) {
+        if (containerInfo.amContainerTask != null) {
+          if (!containerInfo.taskPulled) {
+            containerInfo.taskPulled = true;
+            AMContainerTask amContainerTask = containerInfo.amContainerTask;
+            containerTask = new ContainerTask(amContainerTask.getTask(), false,
+                convertLocalResourceMap(amContainerTask.getAdditionalResources()),
+                amContainerTask.getCredentials(), amContainerTask.haveCredentialsChanged());
+          } else {
+            containerTask = null;
+          }
+        } else {
+          containerTask = null;
+        }
+      }
+    }
+    return containerTask;
+  }
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/fe39ede3/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainer.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainer.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainer.java
index e00ad3d..a6b403d 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainer.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainer.java
@@ -31,9 +31,6 @@ public interface AMContainer extends EventHandler<AMContainerEvent>{
   public ContainerId getContainerId();
   public Container getContainer();
   public List<TezTaskAttemptID> getAllTaskAttempts();
-  public TezTaskAttemptID getRunningTaskAttempt();
-  public List<TezTaskAttemptID> getQueuedTaskAttempts();
+  public TezTaskAttemptID getCurrentTaskAttempt();
   
-  // TODO Add a method to get the containers capabilities - to match taskAttempts.
-
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/fe39ede3/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerEventType.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerEventType.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerEventType.java
index 582ec91..330ad57 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerEventType.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerEventType.java
@@ -28,9 +28,6 @@ public enum AMContainerEventType {
   C_LAUNCHED,
   C_LAUNCH_FAILED,
 
-  //Producer: TAL: PULL_TA is a sync call.
-  C_PULL_TA,
-
   //Producer: Scheduler via TA
   C_TA_SUCCEEDED, // maybe change this to C_TA_FINISHED with a status.
 

http://git-wip-us.apache.org/repos/asf/tez/blob/fe39ede3/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerImpl.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerImpl.java
index 5c5a8c5..f72e62a 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerImpl.java
@@ -18,7 +18,6 @@
 
 package org.apache.tez.dag.app.rm.container;
 
-import java.util.Collections;
 import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.LinkedList;
@@ -62,8 +61,6 @@ import org.apache.tez.dag.history.events.ContainerStoppedEvent;
 import org.apache.tez.dag.records.TaskAttemptTerminationCause;
 import org.apache.tez.dag.records.TezDAGID;
 import org.apache.tez.dag.records.TezTaskAttemptID;
-//import org.apache.tez.dag.app.dag.event.TaskAttemptEventDiagnosticsUpdate;
-import org.apache.tez.runtime.api.impl.TaskSpec;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
@@ -87,12 +84,6 @@ public class AMContainerImpl implements AMContainer {
   private final List<TezTaskAttemptID> completedAttempts =
       new LinkedList<TezTaskAttemptID>();
 
-  // TODO Maybe this should be pulled from the TaskAttempt.s
-  private final Map<TezTaskAttemptID, TaskSpec> remoteTaskMap =
-      new HashMap<TezTaskAttemptID, TaskSpec>();
-
-  // TODO ?? Convert to list and hash.
-
   private long idleTimeBetweenTasks = 0;
   private long lastTaskFinishTime;
 
@@ -103,17 +94,8 @@ public class AMContainerImpl implements AMContainer {
   // be modelled as a separate state.
   private boolean nodeFailed = false;
 
-  private TezTaskAttemptID pendingAttempt;
-  private TezTaskAttemptID runningAttempt;
+  private TezTaskAttemptID currentAttempt;
   private List<TezTaskAttemptID> failedAssignments;
-  private TezTaskAttemptID pullAttempt;
-
-  private AMContainerTask noAllocationContainerTask;
-
-  private static final AMContainerTask NO_MORE_TASKS = new AMContainerTask(
-      true, null, null, null, false);
-  private static final AMContainerTask WAIT_TASK = new AMContainerTask(false,
-      null, null, null, false);
 
   private boolean inError = false;
 
@@ -160,26 +142,19 @@ public class AMContainerImpl implements AMContainer {
           AMContainerState.COMPLETED,
           EnumSet.of(AMContainerEventType.C_LAUNCHED,
               AMContainerEventType.C_LAUNCH_FAILED,
-              AMContainerEventType.C_PULL_TA,
               AMContainerEventType.C_TA_SUCCEEDED,
               AMContainerEventType.C_NM_STOP_SENT,
               AMContainerEventType.C_NM_STOP_FAILED,
               AMContainerEventType.C_TIMED_OUT), new ErrorTransition())
-
       .addTransition(
           AMContainerState.LAUNCHING,
-          EnumSet.of(AMContainerState.LAUNCHING,
-              AMContainerState.STOP_REQUESTED),
+          EnumSet.of(AMContainerState.LAUNCHING, AMContainerState.STOP_REQUESTED),
           AMContainerEventType.C_ASSIGN_TA, new AssignTaskAttemptTransition())
-      .addTransition(AMContainerState.LAUNCHING, AMContainerState.IDLE,
+      .addTransition(AMContainerState.LAUNCHING,
+          EnumSet.of(AMContainerState.IDLE, AMContainerState.RUNNING),
           AMContainerEventType.C_LAUNCHED, new LaunchedTransition())
       .addTransition(AMContainerState.LAUNCHING, AMContainerState.STOPPING,
           AMContainerEventType.C_LAUNCH_FAILED, new LaunchFailedTransition())
-      // TODO CREUSE : Maybe, consider sending back an attempt if the container
-      // asks for one in this state. Waiting for a LAUNCHED event from the
-      // NMComm may delay the task allocation.
-      .addTransition(AMContainerState.LAUNCHING, AMContainerState.LAUNCHING,
-          AMContainerEventType.C_PULL_TA)
       // Is assuming the pullAttempt will be null.
       .addTransition(AMContainerState.LAUNCHING, AMContainerState.COMPLETED,
           AMContainerEventType.C_COMPLETED,
@@ -201,12 +176,9 @@ public class AMContainerImpl implements AMContainer {
           new ErrorAtLaunchingTransition())
 
       .addTransition(AMContainerState.IDLE,
-          EnumSet.of(AMContainerState.IDLE, AMContainerState.STOP_REQUESTED),
+          EnumSet.of(AMContainerState.RUNNING, AMContainerState.STOP_REQUESTED),
           AMContainerEventType.C_ASSIGN_TA,
-          new AssignTaskAttemptAtIdleTransition())
-      .addTransition(AMContainerState.IDLE,
-          EnumSet.of(AMContainerState.RUNNING, AMContainerState.IDLE),
-          AMContainerEventType.C_PULL_TA, new PullTAAtIdleTransition())
+          new AssignTaskAttemptTransition())
       .addTransition(AMContainerState.IDLE, AMContainerState.COMPLETED,
           AMContainerEventType.C_COMPLETED, new CompletedAtIdleTransition())
       .addTransition(AMContainerState.IDLE, AMContainerState.STOP_REQUESTED,
@@ -230,8 +202,6 @@ public class AMContainerImpl implements AMContainer {
       .addTransition(AMContainerState.RUNNING, AMContainerState.STOP_REQUESTED,
           AMContainerEventType.C_ASSIGN_TA,
           new AssignTaskAttemptAtRunningTransition())
-      .addTransition(AMContainerState.RUNNING, AMContainerState.RUNNING,
-          AMContainerEventType.C_PULL_TA)
       .addTransition(AMContainerState.RUNNING, AMContainerState.IDLE,
           AMContainerEventType.C_TA_SUCCEEDED,
           new TASucceededAtRunningTransition())
@@ -259,9 +229,6 @@ public class AMContainerImpl implements AMContainer {
           AMContainerState.STOP_REQUESTED, AMContainerEventType.C_ASSIGN_TA,
           new AssignTAAtWindDownTransition())
       .addTransition(AMContainerState.STOP_REQUESTED,
-          AMContainerState.STOP_REQUESTED, AMContainerEventType.C_PULL_TA,
-          new PullTAAfterStopTransition())
-      .addTransition(AMContainerState.STOP_REQUESTED,
           AMContainerState.COMPLETED, AMContainerEventType.C_COMPLETED,
           new CompletedAtWindDownTransition())
       .addTransition(AMContainerState.STOP_REQUESTED,
@@ -285,10 +252,10 @@ public class AMContainerImpl implements AMContainer {
           AMContainerEventType.C_LAUNCH_REQUEST,
           new ErrorAtNMStopRequestedTransition())
 
+
+
       .addTransition(AMContainerState.STOPPING, AMContainerState.STOPPING,
           AMContainerEventType.C_ASSIGN_TA, new AssignTAAtWindDownTransition())
-      .addTransition(AMContainerState.STOPPING, AMContainerState.STOPPING,
-          AMContainerEventType.C_PULL_TA, new PullTAAfterStopTransition())
       // TODO This transition is wrong. Should be a noop / error.
       .addTransition(AMContainerState.STOPPING, AMContainerState.COMPLETED,
           AMContainerEventType.C_COMPLETED, new CompletedAtWindDownTransition())
@@ -311,8 +278,6 @@ public class AMContainerImpl implements AMContainer {
       .addTransition(AMContainerState.COMPLETED, AMContainerState.COMPLETED,
           AMContainerEventType.C_ASSIGN_TA, new AssignTAAtCompletedTransition())
       .addTransition(AMContainerState.COMPLETED, AMContainerState.COMPLETED,
-          AMContainerEventType.C_PULL_TA, new PullTAAfterStopTransition())
-      .addTransition(AMContainerState.COMPLETED, AMContainerState.COMPLETED,
           AMContainerEventType.C_NODE_FAILED, new NodeFailedBaseTransition())
       .addTransition(
           AMContainerState.COMPLETED,
@@ -348,7 +313,6 @@ public class AMContainerImpl implements AMContainer {
     this.containerHeartbeatHandler = chh;
     this.taskAttemptListener = tal;
     this.failedAssignments = new LinkedList<TezTaskAttemptID>();
-    this.noAllocationContainerTask = WAIT_TASK;
     this.stateMachine = stateMachineFactory.make(this);
   }
 
@@ -379,11 +343,8 @@ public class AMContainerImpl implements AMContainer {
       List<TezTaskAttemptID> allAttempts = new LinkedList<TezTaskAttemptID>();
       allAttempts.addAll(this.completedAttempts);
       allAttempts.addAll(this.failedAssignments);
-      if (this.pendingAttempt != null) {
-        allAttempts.add(this.pendingAttempt);
-      }
-      if (this.runningAttempt != null) {
-        allAttempts.add(this.runningAttempt);
+      if (this.currentAttempt != null) {
+        allAttempts.add(this.currentAttempt);
       }
       return allAttempts;
     } finally {
@@ -392,24 +353,10 @@ public class AMContainerImpl implements AMContainer {
   }
 
   @Override
-  public List<TezTaskAttemptID> getQueuedTaskAttempts() {
-    readLock.lock();
-    try {
-      if (pendingAttempt != null) {
-        return Collections.singletonList(this.pendingAttempt);
-      } else {
-        return Collections.emptyList();
-      }
-    } finally {
-      readLock.unlock();
-    }
-  }
-
-  @Override
-  public TezTaskAttemptID getRunningTaskAttempt() {
+  public TezTaskAttemptID getCurrentTaskAttempt() {
     readLock.lock();
     try {
-      return this.runningAttempt;
+      return this.currentAttempt;
     } finally {
       readLock.unlock();
     }
@@ -453,32 +400,6 @@ public class AMContainerImpl implements AMContainer {
     this.eventHandler.handle(event);
   }
 
-  // Push the TaskAttempt to the TAL, instead of the TAL pulling when a JVM asks
-  // for a TaskAttempt.
-  public AMContainerTask pullTaskContext() {
-    this.writeLock.lock();
-    try {
-      this.handle(
-          new AMContainerEvent(containerId, AMContainerEventType.C_PULL_TA));
-      if (pullAttempt == null) {
-        // As a later optimization, it should be possible for a running container to localize
-        // additional resources before a task is assigned to the container.
-        return noAllocationContainerTask;
-      } else {
-        // Avoid sending credentials if credentials have not changed.
-        AMContainerTask amContainerTask = new AMContainerTask(false,
-            remoteTaskMap.remove(pullAttempt), this.additionalLocalResources,
-            this.credentialsChanged ? this.credentials : null, this.credentialsChanged);
-        this.additionalLocalResources = null;
-        this.credentialsChanged = false;
-        this.pullAttempt = null;
-        return amContainerTask;
-      }
-    } finally {
-      this.writeLock.unlock();
-    }
-  }
-
   //////////////////////////////////////////////////////////////////////////////
   //                   Start of Transition Classes                            //
   //////////////////////////////////////////////////////////////////////////////
@@ -591,14 +512,15 @@ public class AMContainerImpl implements AMContainer {
     public AMContainerState transition(
         AMContainerImpl container, AMContainerEvent cEvent) {
       AMContainerEventAssignTA event = (AMContainerEventAssignTA) cEvent;
-      if (container.pendingAttempt != null) {
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("AssignTaskAttempt at state " + container.getState() + ", attempt: " +
+            ((AMContainerEventAssignTA) cEvent).getRemoteTaskSpec());
+      }
+      if (container.currentAttempt != null) {
         // This may include a couple of additional (harmless) unregister calls
         // to the taskAttemptListener and containerHeartbeatHandler - in case
         // of assign at any state prior to IDLE.
-        container.handleExtraTAAssign(event, container.pendingAttempt);
-        // TODO XXX: Verify that it's ok to send in a NM_STOP_REQUEST. The
-        // NMCommunicator should be able to handle this. The STOP_REQUEST would
-        // only go out after the START_REQUEST.
+        container.handleExtraTAAssign(event, container.currentAttempt);
         return AMContainerState.STOP_REQUESTED;
       }
       
@@ -609,7 +531,7 @@ public class AMContainerImpl implements AMContainer {
           container.containerLocalResources, taskLocalResources);
       // Register the additional resources back for this container.
       container.containerLocalResources.putAll(container.additionalLocalResources);
-      container.pendingAttempt = event.getTaskAttemptId();
+      container.currentAttempt = event.getTaskAttemptId();
       if (LOG.isDebugEnabled()) {
         LOG.debug("AssignTA: attempt: " + event.getRemoteTaskSpec());
         LOG.debug("AdditionalLocalResources: " + container.additionalLocalResources);
@@ -625,17 +547,46 @@ public class AMContainerImpl implements AMContainer {
         container.credentialsChanged = false;
       }
 
-      container.remoteTaskMap
-          .put(event.getTaskAttemptId(), event.getRemoteTaskSpec());
-      return container.getState();
+      if (container.lastTaskFinishTime != 0) {
+        // This effectively measures the time during which nothing was scheduler to execute on a container.
+        // The time from this point to the task actually being available to containers needs to be computed elsewhere.
+        long idleTimeDiff =
+            System.currentTimeMillis() - container.lastTaskFinishTime;
+        container.idleTimeBetweenTasks += idleTimeDiff;
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Computing idle (scheduling) time for container: " +
+              container.getContainerId() + ", lastFinishTime: " +
+              container.lastTaskFinishTime + ", Incremented by: " +
+              idleTimeDiff);
+        }
+      }
+
+      LOG.info("Assigned taskAttempt + [" + container.currentAttempt +
+          "] to container: [" + container.getContainerId() + "]");
+      AMContainerTask amContainerTask = new AMContainerTask(
+          event.getRemoteTaskSpec(), container.additionalLocalResources,
+          container.credentialsChanged ? container.credentials : null, container.credentialsChanged);
+      container.registerAttemptWithListener(amContainerTask);
+      container.additionalLocalResources = null;
+      container.credentialsChanged = false;
+      if (container.getState() == AMContainerState.IDLE) {
+        return AMContainerState.RUNNING;
+      } else {
+        return container.getState();
+      }
     }
   }
 
-  protected static class LaunchedTransition implements
-      SingleArcTransition<AMContainerImpl, AMContainerEvent> {
+  protected static class LaunchedTransition
+      implements MultipleArcTransition<AMContainerImpl, AMContainerEvent, AMContainerState> {
     @Override
-    public void transition(AMContainerImpl container, AMContainerEvent cEvent) {
+    public AMContainerState transition(AMContainerImpl container, AMContainerEvent cEvent) {
       container.registerWithContainerListener();
+      if (container.currentAttempt != null) {
+        return AMContainerState.RUNNING;
+      } else {
+        return AMContainerState.IDLE;
+      }
     }
   }
 
@@ -643,11 +594,11 @@ public class AMContainerImpl implements AMContainer {
       SingleArcTransition<AMContainerImpl, AMContainerEvent> {
     @Override
     public void transition(AMContainerImpl container, AMContainerEvent cEvent) {
-      if (container.pendingAttempt != null) {
+      if (container.currentAttempt != null) {
         AMContainerEventLaunchFailed event = (AMContainerEventLaunchFailed) cEvent;
         // for a properly setup cluster this should almost always be an app error
         // need to differentiate between launch failed due to framework/cluster or app
-        container.sendTerminatingToTaskAttempt(container.pendingAttempt,
+        container.sendTerminatingToTaskAttempt(container.currentAttempt,
             event.getMessage(), TaskAttemptTerminationCause.CONTAINER_LAUNCH_FAILED);
       }
       container.unregisterFromTAListener();
@@ -660,22 +611,22 @@ public class AMContainerImpl implements AMContainer {
     @Override
     public void transition(AMContainerImpl container, AMContainerEvent cEvent) {
       AMContainerEventCompleted event = (AMContainerEventCompleted) cEvent;
-      if (container.pendingAttempt != null) {
+      if (container.currentAttempt!= null) {
         String errorMessage = getMessage(container, event);
         if (event.isSystemAction()) {
-          container.sendContainerTerminatedBySystemToTaskAttempt(container.pendingAttempt,
+          container.sendContainerTerminatedBySystemToTaskAttempt(container.currentAttempt,
               errorMessage, event.getTerminationCause());
         } else {
           container
               .sendTerminatedToTaskAttempt(
-                  container.pendingAttempt,
+                  container.currentAttempt,
                   errorMessage,
                   // if termination cause is generic exited then replace with specific
                   (event.getTerminationCause() == TaskAttemptTerminationCause.CONTAINER_EXITED ? 
                       TaskAttemptTerminationCause.CONTAINER_LAUNCH_FAILED : event.getTerminationCause()));
         }
-        container.registerFailedAttempt(container.pendingAttempt);
-        container.pendingAttempt = null;
+        container.registerFailedAttempt(container.currentAttempt);
+        container.currentAttempt = null;
         LOG.warn(errorMessage);
       }
       container.containerLocalResources = null;
@@ -702,12 +653,12 @@ public class AMContainerImpl implements AMContainer {
 
     @Override
     public void transition(AMContainerImpl container, AMContainerEvent cEvent) {
-      if (container.pendingAttempt != null) {
-        container.sendTerminatingToTaskAttempt(container.pendingAttempt,
+      if (container.currentAttempt != null) {
+        container.sendTerminatingToTaskAttempt(container.currentAttempt,
             getMessage(container, cEvent), TaskAttemptTerminationCause.CONTAINER_STOPPED);
       }
       container.unregisterFromTAListener();
-      container.logStopped(container.pendingAttempt == null ? 
+      container.logStopped(container.currentAttempt == null ?
           ContainerExitStatus.SUCCESS 
           : ContainerExitStatus.INVALID);
       container.sendStopRequestToNM();
@@ -742,18 +693,11 @@ public class AMContainerImpl implements AMContainer {
         container.sendNodeFailureToTA(taId, errorMessage, TaskAttemptTerminationCause.NODE_FAILED);
       }
 
-      if (container.pendingAttempt != null) {
+      if (container.currentAttempt != null) {
         // Will be null in COMPLETED state.
-        container.sendNodeFailureToTA(container.pendingAttempt, errorMessage, 
+        container.sendNodeFailureToTA(container.currentAttempt, errorMessage,
             TaskAttemptTerminationCause.NODE_FAILED);
-        container.sendTerminatingToTaskAttempt(container.pendingAttempt, errorMessage,
-            TaskAttemptTerminationCause.NODE_FAILED);
-      }
-      if (container.runningAttempt != null) {
-        // Will be null in COMPLETED state.
-        container.sendNodeFailureToTA(container.runningAttempt, errorMessage, 
-            TaskAttemptTerminationCause.NODE_FAILED);
-        container.sendTerminatingToTaskAttempt(container.runningAttempt, errorMessage,
+        container.sendTerminatingToTaskAttempt(container.currentAttempt, errorMessage,
             TaskAttemptTerminationCause.NODE_FAILED);
       }
       container.logStopped(ContainerExitStatus.ABORTED);
@@ -775,61 +719,15 @@ public class AMContainerImpl implements AMContainer {
     @Override
     public void transition(AMContainerImpl container, AMContainerEvent cEvent) {
       super.transition(container, cEvent);
-      if (container.pendingAttempt != null) {
-        container.sendTerminatingToTaskAttempt(container.pendingAttempt,
+      if (container.currentAttempt != null) {
+        container.sendTerminatingToTaskAttempt(container.currentAttempt,
             "Container " + container.getContainerId() +
                 " hit an invalid transition - " + cEvent.getType() + " at " +
                 container.getState(), TaskAttemptTerminationCause.FRAMEWORK_ERROR);
       }
       container.logStopped(ContainerExitStatus.ABORTED);
-      container.sendStopRequestToNM();
       container.unregisterFromTAListener();
-    }
-  }
-
-  protected static class AssignTaskAttemptAtIdleTransition
-      extends AssignTaskAttemptTransition {
-    @Override
-    public AMContainerState transition(
-        AMContainerImpl container, AMContainerEvent cEvent) {
-      if (LOG.isDebugEnabled()) {
-        LOG.debug("AssignTAAtIdle: attempt: " +
-            ((AMContainerEventAssignTA) cEvent).getRemoteTaskSpec());
-      }
-      return super.transition(container, cEvent);
-    }
-  }
-
-  protected static class PullTAAtIdleTransition implements
-      MultipleArcTransition<AMContainerImpl, AMContainerEvent, AMContainerState> {
-
-    @Override
-    public AMContainerState transition(
-        AMContainerImpl container, AMContainerEvent cEvent) {
-      if (container.pendingAttempt != null) {
-        // This will be invoked as part of the PULL_REQUEST - so pullAttempt pullAttempt
-        // should ideally only end up being populated during the duration of this call,
-        // which is in a write lock. pullRequest() should move this to the running state.
-        container.pullAttempt = container.pendingAttempt;
-        container.runningAttempt = container.pendingAttempt;
-        container.pendingAttempt = null;
-        if (container.lastTaskFinishTime != 0) {
-          long idleTimeDiff =
-              System.currentTimeMillis() - container.lastTaskFinishTime;
-          container.idleTimeBetweenTasks += idleTimeDiff;
-          if (LOG.isDebugEnabled()) {
-            LOG.debug("Computing idle time for container: " +
-                container.getContainerId() + ", lastFinishTime: " +
-                container.lastTaskFinishTime + ", Incremented by: " +
-                idleTimeDiff);
-          }
-        }
-        LOG.info("Assigned taskAttempt + [" + container.runningAttempt +
-            "] to container: [" + container.getContainerId() + "]");
-        return AMContainerState.RUNNING;
-      } else {
-        return AMContainerState.IDLE;
-      }
+      container.sendStopRequestToNM();
     }
   }
 
@@ -900,8 +798,8 @@ public class AMContainerImpl implements AMContainer {
     public void transition(AMContainerImpl container, AMContainerEvent cEvent) {
 
       AMContainerEventAssignTA event = (AMContainerEventAssignTA) cEvent;
-      container.unregisterAttemptFromListener(container.runningAttempt);
-      container.handleExtraTAAssign(event, container.runningAttempt);
+      container.unregisterAttemptFromListener(container.currentAttempt);
+      container.handleExtraTAAssign(event, container.currentAttempt);
     }
   }
 
@@ -910,9 +808,9 @@ public class AMContainerImpl implements AMContainer {
     @Override
     public void transition(AMContainerImpl container, AMContainerEvent cEvent) {
       container.lastTaskFinishTime = System.currentTimeMillis();
-      container.completedAttempts.add(container.runningAttempt);
-      container.unregisterAttemptFromListener(container.runningAttempt);
-      container.runningAttempt = null;
+      container.completedAttempts.add(container.currentAttempt);
+      container.unregisterAttemptFromListener(container.currentAttempt);
+      container.currentAttempt = null;
     }
   }
 
@@ -922,15 +820,15 @@ public class AMContainerImpl implements AMContainer {
     public void transition(AMContainerImpl container, AMContainerEvent cEvent) {
       AMContainerEventCompleted event = (AMContainerEventCompleted) cEvent;
       if (event.isSystemAction()) {
-        container.sendContainerTerminatedBySystemToTaskAttempt(container.runningAttempt,
+        container.sendContainerTerminatedBySystemToTaskAttempt(container.currentAttempt,
             getMessage(container, event), event.getTerminationCause());
       } else {
-        container.sendTerminatedToTaskAttempt(container.runningAttempt,
+        container.sendTerminatedToTaskAttempt(container.currentAttempt,
             getMessage(container, event), event.getTerminationCause());
       }
-      container.unregisterAttemptFromListener(container.runningAttempt);
-      container.registerFailedAttempt(container.runningAttempt);
-      container.runningAttempt = null;
+      container.unregisterAttemptFromListener(container.currentAttempt);
+      container.registerFailedAttempt(container.currentAttempt);
+      container.currentAttempt= null;
       super.transition(container, cEvent);
     }
   }
@@ -938,11 +836,7 @@ public class AMContainerImpl implements AMContainer {
   protected static class StopRequestAtRunningTransition
       extends StopRequestAtIdleTransition {
     public void transition(AMContainerImpl container, AMContainerEvent cEvent) {
-
-      container.unregisterAttemptFromListener(container.runningAttempt);
-      container.sendTerminatingToTaskAttempt(container.runningAttempt,
-          " Container" + container.getContainerId() + " received a STOP_REQUEST",
-          TaskAttemptTerminationCause.CONTAINER_STOPPED);
+      container.unregisterAttemptFromListener(container.currentAttempt);
       super.transition(container, cEvent);
     }
   }
@@ -963,7 +857,7 @@ public class AMContainerImpl implements AMContainer {
     @Override
     public void transition(AMContainerImpl container, AMContainerEvent cEvent) {
       super.transition(container, cEvent);
-      container.unregisterAttemptFromListener(container.runningAttempt);
+      container.unregisterAttemptFromListener(container.currentAttempt);
     }
   }
 
@@ -972,8 +866,8 @@ public class AMContainerImpl implements AMContainer {
     @Override
     public void transition(AMContainerImpl container, AMContainerEvent cEvent) {
       super.transition(container, cEvent);
-      container.unregisterAttemptFromListener(container.runningAttempt);
-      container.sendTerminatingToTaskAttempt(container.runningAttempt,
+      container.unregisterAttemptFromListener(container.currentAttempt);
+      container.sendTerminatingToTaskAttempt(container.currentAttempt,
           "Container " + container.getContainerId() +
               " hit an invalid transition - " + cEvent.getType() + " at " +
               container.getState(), TaskAttemptTerminationCause.FRAMEWORK_ERROR);
@@ -996,17 +890,6 @@ public class AMContainerImpl implements AMContainer {
     }
   }
 
-  // Hack to some extent. This allocation should be done while entering one of
-  // the post-running states, insetad of being a transition on the post stop
-  // states.
-  protected static class PullTAAfterStopTransition
-      implements SingleArcTransition<AMContainerImpl, AMContainerEvent> {
-    @Override
-    public void transition(AMContainerImpl container, AMContainerEvent cEvent) {
-      container.noAllocationContainerTask = NO_MORE_TASKS;
-    }
-  }
-
   protected static class CompletedAtWindDownTransition implements
       SingleArcTransition<AMContainerImpl, AMContainerEvent> {
     @Override
@@ -1017,17 +900,11 @@ public class AMContainerImpl implements AMContainer {
         container.sendTerminatedToTaskAttempt(taId, diag, 
             TaskAttemptTerminationCause.CONTAINER_EXITED);
       }
-      if (container.pendingAttempt != null) {
-        container.sendTerminatedToTaskAttempt(container.pendingAttempt, diag, 
+      if (container.currentAttempt != null) {
+        container.sendTerminatedToTaskAttempt(container.currentAttempt, diag,
             TaskAttemptTerminationCause.CONTAINER_EXITED);
-        container.registerFailedAttempt(container.pendingAttempt);
-        container.pendingAttempt = null;
-      }
-      if (container.runningAttempt != null) {
-        container.sendTerminatedToTaskAttempt(container.runningAttempt, diag, 
-            TaskAttemptTerminationCause.CONTAINER_EXITED);
-        container.registerFailedAttempt(container.runningAttempt);
-        container.runningAttempt = null;
+        container.registerFailedAttempt(container.currentAttempt);
+        container.currentAttempt = null;
       }
       if (!(diag == null || diag.equals(""))) {
         LOG.info("Container " + container.getContainerId()
@@ -1177,6 +1054,10 @@ public class AMContainerImpl implements AMContainer {
     taskAttemptListener.unregisterTaskAttempt(attemptId);
   }
 
+  protected void registerAttemptWithListener(AMContainerTask amContainerTask) {
+    taskAttemptListener.registerTaskAttempt(amContainerTask, this.containerId);
+  }
+
   protected void registerWithTAListener() {
     taskAttemptListener.registerRunningContainer(containerId);
   }
@@ -1185,7 +1066,6 @@ public class AMContainerImpl implements AMContainer {
     this.taskAttemptListener.unregisterRunningContainer(containerId);
   }
 
-
   protected void registerWithContainerListener() {
     this.containerHeartbeatHandler.register(this.containerId);
   }

http://git-wip-us.apache.org/repos/asf/tez/blob/fe39ede3/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerTask.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerTask.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerTask.java
index efe2cca..89a434b 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerTask.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/container/AMContainerTask.java
@@ -20,30 +20,26 @@ package org.apache.tez.dag.app.rm.container;
 
 import java.util.Map;
 
+import com.google.common.base.Preconditions;
 import org.apache.hadoop.security.Credentials;
 import org.apache.hadoop.yarn.api.records.LocalResource;
 import org.apache.tez.runtime.api.impl.TaskSpec;
 
 public class AMContainerTask {
-  private final boolean shouldDie;
   private final Map<String, LocalResource> additionalResources;
   private final TaskSpec tezTask;
   private final Credentials credentials;
   private final boolean credentialsChanged;
 
-  public AMContainerTask(boolean shouldDie, TaskSpec tezTask,
+  public AMContainerTask(TaskSpec tezTask,
       Map<String, LocalResource> additionalResources, Credentials credentials, boolean credentialsChanged) {
-    this.shouldDie = shouldDie;
+    Preconditions.checkNotNull(tezTask, "TaskSpec cannot be null");
     this.tezTask = tezTask;
     this.additionalResources = additionalResources;
     this.credentials = credentials;
     this.credentialsChanged = credentialsChanged;
   }
 
-  public boolean shouldDie() {
-    return this.shouldDie;
-  }
-
   public TaskSpec getTask() {
     return this.tezTask;
   }

http://git-wip-us.apache.org/repos/asf/tez/blob/fe39ede3/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java
new file mode 100644
index 0000000..599a289
--- /dev/null
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/TestTaskAttemptListenerImplTezDag.java
@@ -0,0 +1,182 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.hadoop.yarn.api.records.ApplicationAccessType;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.event.EventHandler;
+import org.apache.tez.common.ContainerContext;
+import org.apache.tez.common.ContainerTask;
+import org.apache.tez.common.security.JobTokenSecretManager;
+import org.apache.tez.dag.app.dag.DAG;
+import org.apache.tez.dag.app.rm.container.AMContainer;
+import org.apache.tez.dag.app.rm.container.AMContainerMap;
+import org.apache.tez.dag.app.rm.container.AMContainerTask;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.runtime.api.impl.TaskSpec;
+import org.junit.Test;
+
+/** Tests TaskAttemptListenerImpTezDag#getTask as containers and task attempts are (un)registered. */
+public class TestTaskAttemptListenerImplTezDag {
+
+  /** Checks getTask() responses across container and task-attempt registration changes. */
+  @Test(timeout = 5000)
+  public void testGetTask() throws IOException {
+    ApplicationId appId = ApplicationId.newInstance(1000, 1);
+    AppContext appContext = mock(AppContext.class);
+    EventHandler eventHandler = mock(EventHandler.class);
+    DAG dag = mock(DAG.class);
+    AMContainerMap amContainerMap = mock(AMContainerMap.class);
+    Map<ApplicationAccessType, String> appAcls = new HashMap<ApplicationAccessType, String>();
+    doReturn(eventHandler).when(appContext).getEventHandler();
+    doReturn(dag).when(appContext).getCurrentDAG();
+    doReturn(appAcls).when(appContext).getApplicationACLs();
+    doReturn(amContainerMap).when(appContext).getAllContainers();
+
+    TaskAttemptListenerImpTezDag taskAttemptListener =
+        new TaskAttemptListenerImplForTest(appContext, mock(TaskHeartbeatHandler.class),
+            mock(ContainerHeartbeatHandler.class), null);
+
+    TaskSpec taskSpec = mock(TaskSpec.class);
+    TezTaskAttemptID taskAttemptId = mock(TezTaskAttemptID.class);
+    doReturn(taskAttemptId).when(taskSpec).getTaskAttemptID();
+    AMContainerTask amContainerTask = new AMContainerTask(taskSpec, null, null, false);
+    ContainerTask containerTask = null;
+
+    // Container 1 was never registered with the listener, so it must be told to die.
+    ContainerId containerId1 = createContainerId(appId, 1);
+    doReturn(mock(AMContainer.class)).when(amContainerMap).get(containerId1);
+    ContainerContext containerContext1 = new ContainerContext(containerId1.toString());
+    containerTask = taskAttemptListener.getTask(containerContext1);
+    assertTrue(containerTask.shouldDie());
+
+    ContainerId containerId2 = createContainerId(appId, 2);
+    doReturn(mock(AMContainer.class)).when(amContainerMap).get(containerId2);
+    ContainerContext containerContext2 = new ContainerContext(containerId2.toString());
+    taskAttemptListener.registerRunningContainer(containerId2);
+    containerTask = taskAttemptListener.getTask(containerContext2);
+    assertNull(containerTask);
+
+    // Valid task registered
+    taskAttemptListener.registerTaskAttempt(amContainerTask, containerId2);
+    containerTask = taskAttemptListener.getTask(containerContext2);
+    assertFalse(containerTask.shouldDie());
+    assertEquals(taskSpec, containerTask.getTaskSpec());
+
+    // Task unregistered. Should respond to heartbeats
+    taskAttemptListener.unregisterTaskAttempt(taskAttemptId);
+    containerTask = taskAttemptListener.getTask(containerContext2);
+    assertNull(containerTask);
+
+    // Container unregistered. Should send a shouldDie = true
+    taskAttemptListener.unregisterRunningContainer(containerId2);
+    containerTask = taskAttemptListener.getTask(containerContext2);
+    assertTrue(containerTask.shouldDie());
+
+    ContainerId containerId3 = createContainerId(appId, 3);
+    ContainerContext containerContext3 = new ContainerContext(containerId3.toString());
+    taskAttemptListener.registerRunningContainer(containerId3);
+
+    // Register task to container3, followed by unregistering container 3 all together
+    TaskSpec taskSpec2 = mock(TaskSpec.class);
+    TezTaskAttemptID taskAttemptId2 = mock(TezTaskAttemptID.class);
+    doReturn(taskAttemptId2).when(taskSpec2).getTaskAttemptID();
+    // Use taskSpec2 so a distinct attempt (taskAttemptId2) is registered to container 3.
+    AMContainerTask amContainerTask2 = new AMContainerTask(taskSpec2, null, null, false);
+    taskAttemptListener.registerTaskAttempt(amContainerTask2, containerId3);
+    taskAttemptListener.unregisterRunningContainer(containerId3);
+    containerTask = taskAttemptListener.getTask(containerContext3);
+    assertTrue(containerTask.shouldDie());
+  }
+
+  /** A registered task is handed out on the first pull only; a repeat pull gets no task. */
+  @Test(timeout = 5000)
+  public void testGetTaskMultiplePulls() throws IOException {
+    ApplicationId appId = ApplicationId.newInstance(1000, 1);
+    AppContext appContext = mock(AppContext.class);
+    EventHandler eventHandler = mock(EventHandler.class);
+    DAG dag = mock(DAG.class);
+    AMContainerMap amContainerMap = mock(AMContainerMap.class);
+    Map<ApplicationAccessType, String> appAcls = new HashMap<ApplicationAccessType, String>();
+    doReturn(eventHandler).when(appContext).getEventHandler();
+    doReturn(dag).when(appContext).getCurrentDAG();
+    doReturn(appAcls).when(appContext).getApplicationACLs();
+    doReturn(amContainerMap).when(appContext).getAllContainers();
+
+    TaskAttemptListenerImpTezDag taskAttemptListener =
+        new TaskAttemptListenerImplForTest(appContext, mock(TaskHeartbeatHandler.class),
+            mock(ContainerHeartbeatHandler.class), null);
+
+    TaskSpec taskSpec = mock(TaskSpec.class);
+    TezTaskAttemptID taskAttemptId = mock(TezTaskAttemptID.class);
+    doReturn(taskAttemptId).when(taskSpec).getTaskAttemptID();
+    AMContainerTask amContainerTask = new AMContainerTask(taskSpec, null, null, false);
+    ContainerTask containerTask = null;
+
+    ContainerId containerId1 = createContainerId(appId, 1);
+    doReturn(mock(AMContainer.class)).when(amContainerMap).get(containerId1);
+    ContainerContext containerContext1 = new ContainerContext(containerId1.toString());
+    taskAttemptListener.registerRunningContainer(containerId1);
+    containerTask = taskAttemptListener.getTask(containerContext1);
+    assertNull(containerTask);
+
+    // Register task
+    taskAttemptListener.registerTaskAttempt(amContainerTask, containerId1);
+    containerTask = taskAttemptListener.getTask(containerContext1);
+    assertFalse(containerTask.shouldDie());
+    assertEquals(taskSpec, containerTask.getTaskSpec());
+
+    // Try pulling again - simulates re-use pull
+    containerTask = taskAttemptListener.getTask(containerContext1);
+    assertNull(containerTask);
+  }
+
+  private ContainerId createContainerId(ApplicationId applicationId, long containerIdLong) {
+    ApplicationAttemptId appAttemptId = ApplicationAttemptId.newInstance(applicationId, 1);
+    ContainerId containerId = ContainerId.newContainerId(appAttemptId, containerIdLong);
+    return containerId;
+  }
+
+  private static class TaskAttemptListenerImplForTest extends TaskAttemptListenerImpTezDag {
+
+    public TaskAttemptListenerImplForTest(AppContext context,
+                                          TaskHeartbeatHandler thh,
+                                          ContainerHeartbeatHandler chh,
+                                          JobTokenSecretManager jobTokenSecretManager) {
+      super(context, thh, chh, jobTokenSecretManager);
+    }
+
+    @Override
+    protected void startRpcServer() {
+    }
+
+    @Override
+    protected void stopRpcServer() {
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/fe39ede3/tez-dag/src/test/java/org/apache/tez/dag/app/rm/container/TestAMContainer.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/container/TestAMContainer.java b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/container/TestAMContainer.java
index 438c50d..22c0559 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/container/TestAMContainer.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/container/TestAMContainer.java
@@ -20,10 +20,11 @@ package org.apache.tez.dag.app.rm.container;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
@@ -103,40 +104,35 @@ public class TestAMContainer {
     wc.verifyState(AMContainerState.LAUNCHING);
     // 1 Launch request.
     wc.verifyCountAndGetOutgoingEvents(1);
+    verify(wc.tal).registerRunningContainer(wc.containerID);
+    assertNull(wc.amContainer.getCurrentTaskAttempt());
 
     // Assign task.
     wc.assignTaskAttempt(wc.taskAttemptID);
     wc.verifyState(AMContainerState.LAUNCHING);
     wc.verifyNoOutgoingEvents();
-    assertEquals(wc.taskAttemptID, wc.amContainer.getQueuedTaskAttempts()
-        .get(0));
+    assertEquals(wc.taskAttemptID, wc.amContainer.getCurrentTaskAttempt());
 
     // Container Launched
     wc.containerLaunched();
-    wc.verifyState(AMContainerState.IDLE);
-    wc.verifyNoOutgoingEvents();
-    assertEquals(wc.taskAttemptID, wc.amContainer.getQueuedTaskAttempts()
-        .get(0));
-    assertNull(wc.amContainer.getRunningTaskAttempt());
-    verify(wc.tal).registerRunningContainer(wc.containerID);
-    verify(wc.chh).register(wc.containerID);
-
-    // Pull TA
-    AMContainerTask pulledTask = wc.pullTaskToRun();
     wc.verifyState(AMContainerState.RUNNING);
     wc.verifyNoOutgoingEvents();
-    assertFalse(pulledTask.shouldDie());
-    assertEquals(wc.taskSpec.getTaskAttemptID(), pulledTask.getTask()
-        .getTaskAttemptID());
-    assertEquals(wc.taskAttemptID, wc.amContainer.getRunningTaskAttempt());
-    assertEquals(0, wc.amContainer.getQueuedTaskAttempts().size());
+    assertEquals(wc.taskAttemptID, wc.amContainer.getCurrentTaskAttempt());
+    // Once for the previous NO_TASKS, one for the actual task.
+    verify(wc.chh).register(wc.containerID);
+    ArgumentCaptor<AMContainerTask> argumentCaptor = ArgumentCaptor.forClass(AMContainerTask.class);
+    verify(wc.tal, times(1)).registerTaskAttempt(argumentCaptor.capture(), eq(wc.containerID));
+    assertEquals(1, argumentCaptor.getAllValues().size());
+    assertEquals(wc.taskAttemptID, argumentCaptor.getAllValues().get(0).getTask().getTaskAttemptID());
 
+    // Attempt succeeded
     wc.taskAttemptSucceeded(wc.taskAttemptID);
     wc.verifyState(AMContainerState.IDLE);
     wc.verifyNoOutgoingEvents();
-    assertNull(wc.amContainer.getRunningTaskAttempt());
+    assertNull(wc.amContainer.getCurrentTaskAttempt());
     verify(wc.tal).unregisterTaskAttempt(wc.taskAttemptID);
 
+    // Container completed
     wc.containerCompleted();
     wc.verifyHistoryStopEvent();
     wc.verifyState(AMContainerState.COMPLETED);
@@ -160,39 +156,98 @@ public class TestAMContainer {
     wc.verifyState(AMContainerState.LAUNCHING);
     // 1 Launch request.
     wc.verifyCountAndGetOutgoingEvents(1);
+    verify(wc.tal).registerRunningContainer(wc.containerID);
 
     // Container Launched
     wc.containerLaunched();
     wc.verifyState(AMContainerState.IDLE);
     wc.verifyNoOutgoingEvents();
-    assertEquals(0, wc.amContainer.getQueuedTaskAttempts().size());
-    verify(wc.tal).registerRunningContainer(wc.containerID);
+    assertNull(wc.amContainer.getCurrentTaskAttempt());
     verify(wc.chh).register(wc.containerID);
 
     // Assign task.
     wc.assignTaskAttempt(wc.taskAttemptID);
+    wc.verifyState(AMContainerState.RUNNING);
+    wc.verifyNoOutgoingEvents();
+    assertEquals(wc.taskAttemptID, wc.amContainer.getCurrentTaskAttempt());
+    ArgumentCaptor<AMContainerTask> argumentCaptor = ArgumentCaptor.forClass(AMContainerTask.class);
+    verify(wc.tal, times(1)).registerTaskAttempt(argumentCaptor.capture(), eq(wc.containerID));
+    assertEquals(1, argumentCaptor.getAllValues().size());
+    assertEquals(wc.taskAttemptID, argumentCaptor.getAllValues().get(0).getTask().getTaskAttemptID());
+
+    wc.taskAttemptSucceeded(wc.taskAttemptID);
     wc.verifyState(AMContainerState.IDLE);
     wc.verifyNoOutgoingEvents();
-    assertEquals(wc.taskAttemptID, wc.amContainer.getQueuedTaskAttempts()
-        .get(0));
-    assertNull(wc.amContainer.getRunningTaskAttempt());
+    assertNull(wc.amContainer.getCurrentTaskAttempt());
+    verify(wc.tal).unregisterTaskAttempt(wc.taskAttemptID);
 
-    // Pull TA
-    AMContainerTask pulledTask = wc.pullTaskToRun();
+    wc.containerCompleted();
+    wc.verifyHistoryStopEvent();
+    wc.verifyState(AMContainerState.COMPLETED);
+    wc.verifyNoOutgoingEvents();
+    verify(wc.tal).unregisterRunningContainer(wc.containerID);
+    verify(wc.chh).unregister(wc.containerID);
+
+    assertEquals(1, wc.amContainer.getAllTaskAttempts().size());
+    assertFalse(wc.amContainer.isInErrorState());
+  }
+
+  @Test (timeout=5000)
+  // Assign before launch.
+  public void tetMultipleSuccessfulTaskFlow() {
+    WrappedContainer wc = new WrappedContainer();
+
+    wc.verifyState(AMContainerState.ALLOCATED);
+
+    // Launch request.
+    wc.launchContainer();
+    wc.verifyState(AMContainerState.LAUNCHING);
+    // 1 Launch request.
+    wc.verifyCountAndGetOutgoingEvents(1);
+    verify(wc.tal).registerRunningContainer(wc.containerID);
+    assertNull(wc.amContainer.getCurrentTaskAttempt());
+
+    // Assign task.
+    wc.assignTaskAttempt(wc.taskAttemptID);
+    wc.verifyState(AMContainerState.LAUNCHING);
+    wc.verifyNoOutgoingEvents();
+    assertEquals(wc.taskAttemptID, wc.amContainer.getCurrentTaskAttempt());
+
+    // Container Launched
+    wc.containerLaunched();
     wc.verifyState(AMContainerState.RUNNING);
     wc.verifyNoOutgoingEvents();
-    assertFalse(pulledTask.shouldDie());
-    assertEquals(wc.taskSpec.getTaskAttemptID(), pulledTask.getTask()
-        .getTaskAttemptID());
-    assertEquals(wc.taskAttemptID, wc.amContainer.getRunningTaskAttempt());
-    assertEquals(0, wc.amContainer.getQueuedTaskAttempts().size());
+    assertEquals(wc.taskAttemptID, wc.amContainer.getCurrentTaskAttempt());
+    // Once for the previous NO_TASKS, one for the actual task.
+    verify(wc.chh).register(wc.containerID);
+    ArgumentCaptor<AMContainerTask> argumentCaptor = ArgumentCaptor.forClass(AMContainerTask.class);
+    verify(wc.tal, times(1)).registerTaskAttempt(argumentCaptor.capture(), eq(wc.containerID));
+    assertEquals(1, argumentCaptor.getAllValues().size());
+    assertEquals(wc.taskAttemptID, argumentCaptor.getAllValues().get(0).getTask().getTaskAttemptID());
 
+    // Attempt succeeded
     wc.taskAttemptSucceeded(wc.taskAttemptID);
     wc.verifyState(AMContainerState.IDLE);
     wc.verifyNoOutgoingEvents();
-    assertNull(wc.amContainer.getRunningTaskAttempt());
+    assertNull(wc.amContainer.getCurrentTaskAttempt());
     verify(wc.tal).unregisterTaskAttempt(wc.taskAttemptID);
 
+    TezTaskAttemptID taId2 = TezTaskAttemptID.getInstance(wc.taskID, 2);
+    wc.assignTaskAttempt(taId2);
+    wc.verifyState(AMContainerState.RUNNING);
+    argumentCaptor = ArgumentCaptor.forClass(AMContainerTask.class);
+    verify(wc.tal, times(2)).registerTaskAttempt(argumentCaptor.capture(), eq(wc.containerID));
+    assertEquals(2, argumentCaptor.getAllValues().size());
+    assertEquals(taId2, argumentCaptor.getAllValues().get(1).getTask().getTaskAttemptID());
+
+    // Attempt succeeded
+    wc.taskAttemptSucceeded(taId2);
+    wc.verifyState(AMContainerState.IDLE);
+    wc.verifyNoOutgoingEvents();
+    assertNull(wc.amContainer.getCurrentTaskAttempt());
+    verify(wc.tal).unregisterTaskAttempt(taId2);
+
+    // Container completed
     wc.containerCompleted();
     wc.verifyHistoryStopEvent();
     wc.verifyState(AMContainerState.COMPLETED);
@@ -200,7 +255,7 @@ public class TestAMContainer {
     verify(wc.tal).unregisterRunningContainer(wc.containerID);
     verify(wc.chh).unregister(wc.containerID);
 
-    assertEquals(1, wc.amContainer.getAllTaskAttempts().size());
+    assertEquals(2, wc.amContainer.getAllTaskAttempts().size());
     assertFalse(wc.amContainer.isInErrorState());
   }
 
@@ -213,7 +268,6 @@ public class TestAMContainer {
     wc.launchContainer();
     wc.assignTaskAttempt(wc.taskAttemptID);
     wc.containerLaunched();
-    wc.pullTaskToRun();
     wc.taskAttemptSucceeded(wc.taskAttemptID);
 
     wc.stopRequest();
@@ -234,8 +288,7 @@ public class TestAMContainer {
     verify(wc.tal).unregisterRunningContainer(wc.containerID);
     verify(wc.chh).unregister(wc.containerID);
 
-    assertEquals(0, wc.amContainer.getQueuedTaskAttempts().size());
-    assertNull(wc.amContainer.getRunningTaskAttempt());
+    assertNull(wc.amContainer.getCurrentTaskAttempt());
     assertEquals(1, wc.amContainer.getAllTaskAttempts().size());
     assertFalse(wc.amContainer.isInErrorState());
   }
@@ -249,7 +302,6 @@ public class TestAMContainer {
     wc.launchContainer();
     wc.assignTaskAttempt(wc.taskAttemptID);
     wc.containerLaunched();
-    wc.pullTaskToRun();
     wc.taskAttemptSucceeded(wc.taskAttemptID);
 
     wc.stopRequest();
@@ -273,22 +325,21 @@ public class TestAMContainer {
     verify(wc.tal).unregisterRunningContainer(wc.containerID);
     verify(wc.chh).unregister(wc.containerID);
 
-    assertEquals(0, wc.amContainer.getQueuedTaskAttempts().size());
-    assertNull(wc.amContainer.getRunningTaskAttempt());
+    assertNull(wc.amContainer.getCurrentTaskAttempt());
     assertEquals(1, wc.amContainer.getAllTaskAttempts().size());
     assertFalse(wc.amContainer.isInErrorState());
   }
 
   @SuppressWarnings("rawtypes")
   @Test (timeout=5000)
-  public void testMultipleAllocationsAtIdle() {
+  public void testMultipleAllocationsWhileActive() {
     WrappedContainer wc = new WrappedContainer();
     List<Event> outgoingEvents;
 
     wc.launchContainer();
     wc.containerLaunched();
     wc.assignTaskAttempt(wc.taskAttemptID);
-    wc.verifyState(AMContainerState.IDLE);
+    wc.verifyState(AMContainerState.RUNNING);
 
     TezTaskAttemptID taID2 = TezTaskAttemptID.getInstance(wc.taskID, 2);
     wc.assignTaskAttempt(taID2);
@@ -313,22 +364,20 @@ public class TestAMContainer {
         TaskAttemptEventType.TA_CONTAINER_TERMINATED,
         TaskAttemptEventType.TA_CONTAINER_TERMINATED);
 
-    assertNull(wc.amContainer.getRunningTaskAttempt());
-    assertEquals(0, wc.amContainer.getQueuedTaskAttempts().size());
+    assertNull(wc.amContainer.getCurrentTaskAttempt());
     assertEquals(2, wc.amContainer.getAllTaskAttempts().size());
   }
 
   @SuppressWarnings("rawtypes")
   @Test (timeout=5000)
-  public void testAllocationAtRunning() {
+  public void testMultipleAllocationsAtLaunching() {
     WrappedContainer wc = new WrappedContainer();
     List<Event> outgoingEvents;
 
     wc.launchContainer();
-    wc.containerLaunched();
     wc.assignTaskAttempt(wc.taskAttemptID);
-    wc.pullTaskToRun();
-    wc.verifyState(AMContainerState.RUNNING);
+    wc.verifyState(AMContainerState.LAUNCHING);
+    verify(wc.tal).registerRunningContainer(wc.containerID);
 
     TezTaskAttemptID taID2 = TezTaskAttemptID.getInstance(wc.taskID, 2);
     wc.assignTaskAttempt(taID2);
@@ -353,63 +402,56 @@ public class TestAMContainer {
         TaskAttemptEventType.TA_CONTAINER_TERMINATED,
         TaskAttemptEventType.TA_CONTAINER_TERMINATED);
 
-    assertNull(wc.amContainer.getRunningTaskAttempt());
-    assertEquals(0, wc.amContainer.getQueuedTaskAttempts().size());
+    assertNull(wc.amContainer.getCurrentTaskAttempt());
     assertEquals(2, wc.amContainer.getAllTaskAttempts().size());
   }
 
   @SuppressWarnings("rawtypes")
   @Test (timeout=5000)
-  public void testMultipleAllocationsAtLaunching() {
+  public void testContainerTimedOutAtRunning() {
     WrappedContainer wc = new WrappedContainer();
     List<Event> outgoingEvents;

     wc.launchContainer();
+    wc.containerLaunched();
     wc.assignTaskAttempt(wc.taskAttemptID);
-    wc.pullTaskToRun();
-    wc.verifyState(AMContainerState.LAUNCHING);
-
-    TezTaskAttemptID taID2 = TezTaskAttemptID.getInstance(wc.taskID, 2);
-    wc.assignTaskAttempt(taID2);
+    wc.verifyState(AMContainerState.RUNNING);

+    wc.containerTimedOut();
     wc.verifyState(AMContainerState.STOP_REQUESTED);
     verify(wc.tal).unregisterRunningContainer(wc.containerID);
     verify(wc.chh).unregister(wc.containerID);
-    // 1 for NM stop request. 2 TERMINATING to TaskAttempt.
-    outgoingEvents = wc.verifyCountAndGetOutgoingEvents(3);
+    // 1 TERMINATING to TaskAttempt, 1 for NM stop request.
+    outgoingEvents = wc.verifyCountAndGetOutgoingEvents(2);
     verifyUnOrderedOutgoingEventTypes(outgoingEvents,
-        NMCommunicatorEventType.CONTAINER_STOP_REQUEST,
         TaskAttemptEventType.TA_CONTAINER_TERMINATING,
-        TaskAttemptEventType.TA_CONTAINER_TERMINATING);
-    assertTrue(wc.amContainer.isInErrorState());
+        NMCommunicatorEventType.CONTAINER_STOP_REQUEST);
+    // TODO Should a timeout send an RM de-allocate instead of an NM stop request?

-    wc.nmStopSent();
     wc.containerCompleted();
     wc.verifyHistoryStopEvent();
-    // 1 Inform scheduler. 2 TERMINATED to TaskAttempt.
-    outgoingEvents = wc.verifyCountAndGetOutgoingEvents(2);
+    outgoingEvents = wc.verifyCountAndGetOutgoingEvents(1); // 1 TERMINATED to TaskAttempt.
     verifyUnOrderedOutgoingEventTypes(outgoingEvents,
-        TaskAttemptEventType.TA_CONTAINER_TERMINATED,
         TaskAttemptEventType.TA_CONTAINER_TERMINATED);

-    assertNull(wc.amContainer.getRunningTaskAttempt());
-    assertEquals(0, wc.amContainer.getQueuedTaskAttempts().size());
-    assertEquals(2, wc.amContainer.getAllTaskAttempts().size());
+    assertFalse(wc.amContainer.isInErrorState());
+
+    assertNull(wc.amContainer.getCurrentTaskAttempt());
+    assertEquals(1, wc.amContainer.getAllTaskAttempts().size());
   }
 
   @SuppressWarnings("rawtypes")
   @Test (timeout=5000)
-  public void testContainerTimedOutAtRunning() {
+  public void testStopRequestedAtRunning() {
     WrappedContainer wc = new WrappedContainer();
     List<Event> outgoingEvents;
 
     wc.launchContainer();
     wc.containerLaunched();
     wc.assignTaskAttempt(wc.taskAttemptID);
-    wc.pullTaskToRun();
     wc.verifyState(AMContainerState.RUNNING);
 
-    wc.containerTimedOut();
+    wc.stopRequest();
     wc.verifyState(AMContainerState.STOP_REQUESTED);
     verify(wc.tal).unregisterRunningContainer(wc.containerID);
     verify(wc.chh).unregister(wc.containerID);
@@ -428,8 +470,7 @@ public class TestAMContainer {
 
     assertFalse(wc.amContainer.isInErrorState());
 
-    assertNull(wc.amContainer.getRunningTaskAttempt());
-    assertEquals(0, wc.amContainer.getQueuedTaskAttempts().size());
+    assertNull(wc.amContainer.getCurrentTaskAttempt());
     assertEquals(1, wc.amContainer.getAllTaskAttempts().size());
   }
 
@@ -442,6 +483,7 @@ public class TestAMContainer {
     wc.launchContainer();
     wc.assignTaskAttempt(wc.taskAttemptID);
     wc.verifyState(AMContainerState.LAUNCHING);
+    verify(wc.tal).registerRunningContainer(wc.containerID);
     wc.launchFailed();
     wc.verifyState(AMContainerState.STOPPING);
     verify(wc.tal).registerRunningContainer(wc.containerID);
@@ -454,7 +496,7 @@ public class TestAMContainer {
     for (Event e : outgoingEvents) {
       if (e.getType() == TaskAttemptEventType.TA_CONTAINER_TERMINATING) {
         Assert.assertEquals(TaskAttemptTerminationCause.CONTAINER_LAUNCH_FAILED,
-            ((TaskAttemptEventContainerTerminating)e).getTerminationCause());        
+            ((TaskAttemptEventContainerTerminating)e).getTerminationCause());
       }
     }
 
@@ -462,7 +504,7 @@ public class TestAMContainer {
     outgoingEvents = wc.verifyCountAndGetOutgoingEvents(1);
     verifyUnOrderedOutgoingEventTypes(outgoingEvents,
         TaskAttemptEventType.TA_CONTAINER_TERMINATED);
-    
+
     // Valid transition. Container complete, but not with an error.
     assertFalse(wc.amContainer.isInErrorState());
   }
@@ -511,7 +553,7 @@ public class TestAMContainer {
 
     assertFalse(wc.amContainer.isInErrorState());
   }
-  
+
   @SuppressWarnings("rawtypes")
   @Test (timeout=5000)
   public void testContainerCompletedAtLaunchingSpecificClusterError() {
@@ -520,7 +562,6 @@ public class TestAMContainer {
 
     wc.launchContainer();
 
-
     wc.assignTaskAttempt(wc.taskAttemptID);
 
     wc.containerCompleted(ContainerExitStatus.DISKS_FAILED, TaskAttemptTerminationCause.NODE_DISK_ERROR);
@@ -542,7 +583,7 @@ public class TestAMContainer {
 
     assertFalse(wc.amContainer.isInErrorState());
   }
-  
+
   @SuppressWarnings("rawtypes")
   @Test (timeout=5000)
   public void testContainerCompletedAtLaunchingSpecificError() {
@@ -582,7 +623,6 @@ public class TestAMContainer {
 
     wc.launchContainer();
 
-    wc.assignTaskAttempt(wc.taskAttemptID);
     wc.containerLaunched();
     wc.verifyState(AMContainerState.IDLE);
 
@@ -593,16 +633,10 @@ public class TestAMContainer {
     verify(wc.chh).register(wc.containerID);
     verify(wc.chh).unregister(wc.containerID);
 
-    outgoingEvents = wc.verifyCountAndGetOutgoingEvents(1);
-    verifyUnOrderedOutgoingEventTypes(outgoingEvents,
-        TaskAttemptEventType.TA_CONTAINER_TERMINATED);
+    wc.verifyCountAndGetOutgoingEvents(0);
 
     assertFalse(wc.amContainer.isInErrorState());
 
-    // Pending pull request. (Ideally, container should be dead at this point
-    // and this event should not be generated. Network timeout on NM-RM heartbeat
-    // can cause it to be genreated)
-    wc.pullTaskToRun();
     wc.verifyNoOutgoingEvents();
     wc.verifyHistoryStopEvent();
 
@@ -619,7 +653,6 @@ public class TestAMContainer {
 
     wc.assignTaskAttempt(wc.taskAttemptID);
     wc.containerLaunched();
-    wc.pullTaskToRun();
     wc.verifyState(AMContainerState.RUNNING);
 
     wc.containerCompleted();
@@ -655,7 +688,6 @@ public class TestAMContainer {
 
     wc.assignTaskAttempt(wc.taskAttemptID);
     wc.containerLaunched();
-    wc.pullTaskToRun();
     wc.verifyState(AMContainerState.RUNNING);
 
     wc.containerCompleted(ContainerExitStatus.PREEMPTED, TaskAttemptTerminationCause.EXTERNAL_PREEMPTION);
@@ -693,7 +725,6 @@ public class TestAMContainer {
 
     wc.assignTaskAttempt(wc.taskAttemptID);
     wc.containerLaunched();
-    wc.pullTaskToRun();
     wc.verifyState(AMContainerState.RUNNING);
 
     wc.containerCompleted(ContainerExitStatus.INVALID, TaskAttemptTerminationCause.INTERNAL_PREEMPTION);
@@ -720,7 +751,7 @@ public class TestAMContainer {
 
     assertFalse(wc.amContainer.isInErrorState());
   }
-  
+
   @SuppressWarnings("rawtypes")
   @Test (timeout=5000)
   public void testContainerDiskFailedAtRunning() {
@@ -731,7 +762,6 @@ public class TestAMContainer {
 
     wc.assignTaskAttempt(wc.taskAttemptID);
     wc.containerLaunched();
-    wc.pullTaskToRun();
     wc.verifyState(AMContainerState.RUNNING);
 
     wc.containerCompleted(ContainerExitStatus.DISKS_FAILED, TaskAttemptTerminationCause.NODE_DISK_ERROR);
@@ -758,7 +788,7 @@ public class TestAMContainer {
 
     assertFalse(wc.amContainer.isInErrorState());
   }
-  
+
   @SuppressWarnings("rawtypes")
   @Test (timeout=5000)
   public void testTaskAssignedToCompletedContainer() {
@@ -768,7 +798,6 @@ public class TestAMContainer {
     wc.launchContainer();
     wc.containerLaunched();
     wc.assignTaskAttempt(wc.taskAttemptID);
-    wc.pullTaskToRun();
     wc.taskAttemptSucceeded(wc.taskAttemptID);
 
     wc.containerCompleted();
@@ -791,28 +820,16 @@ public class TestAMContainer {
     assertTrue(wc.amContainer.isInErrorState());
   }
 
-  @Test (timeout=5000)
-  public void testTaskPullAtLaunching() {
-    WrappedContainer wc = new WrappedContainer();
-
-    wc.launchContainer();
-    AMContainerTask pulledTask = wc.pullTaskToRun();
-    wc.verifyState(AMContainerState.LAUNCHING);
-    wc.verifyNoOutgoingEvents();
-    assertFalse(pulledTask.shouldDie());
-    assertNull(pulledTask.getTask());
-  }
-
   @SuppressWarnings("rawtypes")
   @Test (timeout=5000)
-  public void testNodeFailedAtIdle() {
+  public void testNodeFailedAtRunning() {
     WrappedContainer wc = new WrappedContainer();
     List<Event> outgoingEvents;
 
     wc.launchContainer();
     wc.containerLaunched();
     wc.assignTaskAttempt(wc.taskAttemptID);
-    wc.verifyState(AMContainerState.IDLE);
+    wc.verifyState(AMContainerState.RUNNING);
 
     wc.nodeFailed();
     // Expecting a complete event from the RM
@@ -848,13 +865,11 @@ public class TestAMContainer {
     wc.launchContainer();
     wc.containerLaunched();
     wc.assignTaskAttempt(wc.taskAttemptID);
-    wc.pullTaskToRun();
     wc.taskAttemptSucceeded(wc.taskAttemptID);
     wc.verifyState(AMContainerState.IDLE);
 
     TezTaskAttemptID taID2 = TezTaskAttemptID.getInstance(wc.taskID, 2);
     wc.assignTaskAttempt(taID2);
-    wc.pullTaskToRun();
     wc.taskAttemptSucceeded(taID2);
     wc.verifyState(AMContainerState.IDLE);
 
@@ -880,8 +895,7 @@ public class TestAMContainer {
     wc.verifyNoOutgoingEvents();
     wc.verifyHistoryStopEvent();
 
-    assertNull(wc.amContainer.getRunningTaskAttempt());
-    assertEquals(0, wc.amContainer.getQueuedTaskAttempts().size());
+    assertNull(wc.amContainer.getCurrentTaskAttempt());
     assertEquals(2, wc.amContainer.getAllTaskAttempts().size());
   }
 
@@ -894,12 +908,10 @@ public class TestAMContainer {
     wc.launchContainer();
     wc.containerLaunched();
     wc.assignTaskAttempt(wc.taskAttemptID);
-    wc.pullTaskToRun();
     wc.taskAttemptSucceeded(wc.taskAttemptID);
 
     TezTaskAttemptID taID2 = TezTaskAttemptID.getInstance(wc.taskID, 2);
     wc.assignTaskAttempt(taID2);
-    wc.pullTaskToRun();
     wc.verifyState(AMContainerState.RUNNING);
 
     wc.nodeFailed();
@@ -926,8 +938,7 @@ public class TestAMContainer {
         TaskAttemptEventType.TA_CONTAINER_TERMINATED);
 
     assertFalse(wc.amContainer.isInErrorState());
-    assertNull(wc.amContainer.getRunningTaskAttempt());
-    assertEquals(0, wc.amContainer.getQueuedTaskAttempts().size());
+    assertNull(wc.amContainer.getCurrentTaskAttempt());
     assertEquals(2, wc.amContainer.getAllTaskAttempts().size());
   }
 
@@ -940,12 +951,10 @@ public class TestAMContainer {
     wc.launchContainer();
     wc.containerLaunched();
     wc.assignTaskAttempt(wc.taskAttemptID);
-    wc.pullTaskToRun();
     wc.taskAttemptSucceeded(wc.taskAttemptID);
 
     TezTaskAttemptID taID2 = TezTaskAttemptID.getInstance(wc.taskID, 2);
     wc.assignTaskAttempt(taID2);
-    wc.pullTaskToRun();
     wc.taskAttemptSucceeded(taID2);
     wc.stopRequest();
     wc.nmStopSent();
@@ -958,8 +967,7 @@ public class TestAMContainer {
         TaskAttemptEventType.TA_NODE_FAILED,
         TaskAttemptEventType.TA_NODE_FAILED);
 
-    assertNull(wc.amContainer.getRunningTaskAttempt());
-    assertEquals(0, wc.amContainer.getQueuedTaskAttempts().size());
+    assertNull(wc.amContainer.getCurrentTaskAttempt());
     assertEquals(2, wc.amContainer.getAllTaskAttempts().size());
   }
 
@@ -970,12 +978,10 @@ public class TestAMContainer {
     wc.launchContainer();
     wc.containerLaunched();
     wc.assignTaskAttempt(wc.taskAttemptID);
-    wc.pullTaskToRun();
     wc.taskAttemptSucceeded(wc.taskAttemptID);
 
     TezTaskAttemptID taID2 = TezTaskAttemptID.getInstance(wc.taskID, 2);
     wc.assignTaskAttempt(taID2);
-    wc.pullTaskToRun();
     wc.taskAttemptSucceeded(taID2);
     wc.stopRequest();
     wc.nmStopSent();
@@ -988,7 +994,7 @@ public class TestAMContainer {
     wc.verifyNoOutgoingEvents();
     wc.verifyHistoryStopEvent();
   }
-  
+
   @Test (timeout=5000)
   public void testLocalResourceAddition() {
     WrappedContainer wc = new WrappedContainer();
@@ -1003,7 +1009,9 @@ public class TestAMContainer {
     wc.launchContainer(initialResources, new Credentials());
     wc.containerLaunched();
     wc.assignTaskAttempt(wc.taskAttemptID);
-    AMContainerTask task1 = wc.pullTaskToRun();
+    ArgumentCaptor<AMContainerTask> argumentCaptor = ArgumentCaptor.forClass(AMContainerTask.class);
+    verify(wc.tal, times(1)).registerTaskAttempt(argumentCaptor.capture(), eq(wc.containerID));
+    AMContainerTask task1 = argumentCaptor.getAllValues().get(0);
     assertEquals(0, task1.getAdditionalResources().size());
     wc.taskAttemptSucceeded(wc.taskAttemptID);
 
@@ -1014,7 +1022,9 @@ public class TestAMContainer {
 
     TezTaskAttemptID taID2 = TezTaskAttemptID.getInstance(wc.taskID, 2);
     wc.assignTaskAttempt(taID2, additionalResources, new Credentials());
-    AMContainerTask task2 = wc.pullTaskToRun();
+    argumentCaptor = ArgumentCaptor.forClass(AMContainerTask.class);
+    verify(wc.tal, times(2)).registerTaskAttempt(argumentCaptor.capture(), eq(wc.containerID));
+    AMContainerTask task2 = argumentCaptor.getAllValues().get(1);
     Map<String, LocalResource> pullTaskAdditionalResources = task2.getAdditionalResources();
     assertEquals(2, pullTaskAdditionalResources.size());
     pullTaskAdditionalResources.remove(rsrc2);
@@ -1035,7 +1045,9 @@ public class TestAMContainer {
     // task is not asked to re-localize again.
     TezTaskAttemptID taID3 = TezTaskAttemptID.getInstance(wc.taskID, 3);
     wc.assignTaskAttempt(taID3, new HashMap<String, LocalResource>(), new Credentials());
-    AMContainerTask task3 = wc.pullTaskToRun();
+    argumentCaptor = ArgumentCaptor.forClass(AMContainerTask.class);
+    verify(wc.tal, times(3)).registerTaskAttempt(argumentCaptor.capture(), eq(wc.containerID));
+    AMContainerTask task3 = argumentCaptor.getAllValues().get(2);
     assertEquals(0, task3.getAdditionalResources().size());
     wc.taskAttemptSucceeded(taID3);
 
@@ -1063,66 +1075,79 @@ public class TestAMContainer {
     TezTaskAttemptID attempt22 = TezTaskAttemptID.getInstance(taskID2, 300);
     TezTaskAttemptID attempt31 = TezTaskAttemptID.getInstance(taskID3, 200);
     TezTaskAttemptID attempt32 = TezTaskAttemptID.getInstance(taskID3, 300);
-    
+
     Map<String, LocalResource> LRs = new HashMap<String, LocalResource>();
     AMContainerTask fetchedTask = null;
-    
+    ArgumentCaptor<AMContainerTask> argumentCaptor = null;
+
     Token<TokenIdentifier> amGenToken = mock(Token.class);
     Token<TokenIdentifier> token1 = mock(Token.class);
     Token<TokenIdentifier> token3 = mock(Token.class);
-    
+
     Credentials containerCredentials = new Credentials();
     TokenCache.setSessionToken(amGenToken, containerCredentials);
 
     Text token1Name = new Text("tokenDag1");
     Text token3Name = new Text("tokenDag3");
-    
+
     Credentials dag1Credentials = new Credentials();
     dag1Credentials.addToken(new Text(token1Name), token1);
     Credentials dag3Credentials = new Credentials();
     dag3Credentials.addToken(new Text(token3Name), token3);
-    
+
     wc.launchContainer(new HashMap<String, LocalResource>(), containerCredentials);
     wc.containerLaunched();
-    wc.assignTaskAttempt(attempt11, LRs , dag1Credentials);
-    fetchedTask = wc.pullTaskToRun();
+    wc.assignTaskAttempt(attempt11, LRs, dag1Credentials);
+    argumentCaptor = ArgumentCaptor.forClass(AMContainerTask.class);
+    verify(wc.tal, times(1)).registerTaskAttempt(argumentCaptor.capture(), eq(wc.containerID));
+    fetchedTask = argumentCaptor.getAllValues().get(0);
     assertTrue(fetchedTask.haveCredentialsChanged());
     assertNotNull(fetchedTask.getCredentials());
     assertNotNull(fetchedTask.getCredentials().getToken(token1Name));
     wc.taskAttemptSucceeded(attempt11);
-    
+
     wc.assignTaskAttempt(attempt12, LRs, dag1Credentials);
-    fetchedTask = wc.pullTaskToRun();
+    argumentCaptor = ArgumentCaptor.forClass(AMContainerTask.class);
+    verify(wc.tal, times(2)).registerTaskAttempt(argumentCaptor.capture(), eq(wc.containerID));
+    fetchedTask = argumentCaptor.getAllValues().get(1);
     assertFalse(fetchedTask.haveCredentialsChanged());
     assertNull(fetchedTask.getCredentials());
     wc.taskAttemptSucceeded(attempt12);
-    
+
     // Move to running a second DAG, with no credentials.
     wc.setNewDAGID(dagID2);
     wc.assignTaskAttempt(attempt21, LRs, null);
-    fetchedTask = wc.pullTaskToRun();
+    argumentCaptor = ArgumentCaptor.forClass(AMContainerTask.class);
+    verify(wc.tal, times(3)).registerTaskAttempt(argumentCaptor.capture(), eq(wc.containerID));
+    fetchedTask = argumentCaptor.getAllValues().get(2);
     assertTrue(fetchedTask.haveCredentialsChanged());
     assertNull(fetchedTask.getCredentials());
     wc.taskAttemptSucceeded(attempt21);
-    
+
     wc.assignTaskAttempt(attempt22, LRs, null);
-    fetchedTask = wc.pullTaskToRun();
+    argumentCaptor = ArgumentCaptor.forClass(AMContainerTask.class);
+    verify(wc.tal, times(4)).registerTaskAttempt(argumentCaptor.capture(), eq(wc.containerID));
+    fetchedTask = argumentCaptor.getAllValues().get(3);
     assertFalse(fetchedTask.haveCredentialsChanged());
     assertNull(fetchedTask.getCredentials());
     wc.taskAttemptSucceeded(attempt22);
-    
+
     // Move to running a third DAG, with Credentials this time
     wc.setNewDAGID(dagID3);
     wc.assignTaskAttempt(attempt31, LRs , dag3Credentials);
-    fetchedTask = wc.pullTaskToRun();
+    argumentCaptor = ArgumentCaptor.forClass(AMContainerTask.class);
+    verify(wc.tal, times(5)).registerTaskAttempt(argumentCaptor.capture(), eq(wc.containerID));
+    fetchedTask = argumentCaptor.getAllValues().get(4);
     assertTrue(fetchedTask.haveCredentialsChanged());
     assertNotNull(fetchedTask.getCredentials());
     assertNotNull(fetchedTask.getCredentials().getToken(token3Name));
     assertNull(fetchedTask.getCredentials().getToken(token1Name));
     wc.taskAttemptSucceeded(attempt31);
-    
+
     wc.assignTaskAttempt(attempt32, LRs, dag1Credentials);
-    fetchedTask = wc.pullTaskToRun();
+    argumentCaptor = ArgumentCaptor.forClass(AMContainerTask.class);
+    verify(wc.tal, times(6)).registerTaskAttempt(argumentCaptor.capture(), eq(wc.containerID));
+    fetchedTask = argumentCaptor.getAllValues().get(5);
     assertFalse(fetchedTask.haveCredentialsChanged());
     assertNull(fetchedTask.getCredentials());
     wc.taskAttemptSucceeded(attempt32);
@@ -1261,15 +1286,11 @@ public class TestAMContainer {
     public void assignTaskAttempt(TezTaskAttemptID taID,
         Map<String, LocalResource> additionalResources, Credentials credentials) {
       reset(eventHandler);
+      doReturn(taID).when(taskSpec).getTaskAttemptID(); // NOTE(review): shared mock, earlier captures may report this taID
       amContainer.handle(new AMContainerEventAssignTA(containerID, taID, taskSpec,
           additionalResources, credentials));
     }
 
-    public AMContainerTask pullTaskToRun() {
-      reset(eventHandler);
-      return amContainer.pullTaskContext();
-    }
-
     public void containerLaunched() {
       reset(eventHandler);
       amContainer.handle(new AMContainerEventLaunched(containerID));