You are viewing a plain-text version of this content. The canonical link to it was a hyperlink ("here") in the original page and is not preserved in this text.
Posted to commits@tez.apache.org by zj...@apache.org on 2014/12/18 03:24:15 UTC

tez git commit: TEZ-1769. ContainerCompletedWhileRunningTransition should inherit from TerminatedWhileRunningTransition (zjffdu)

Repository: tez
Updated Branches:
  refs/heads/master 3e5046e0b -> 32b437655


TEZ-1769. ContainerCompletedWhileRunningTransition should inherit from TerminatedWhileRunningTransition (zjffdu)


Project: http://git-wip-us.apache.org/repos/asf/tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/32b43765
Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/32b43765
Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/32b43765

Branch: refs/heads/master
Commit: 32b437655cc18f0ec34e45251c9b09cefb8feb00
Parents: 3e5046e
Author: Jeff Zhang <zj...@apache.org>
Authored: Thu Dec 18 10:23:25 2014 +0800
Committer: Jeff Zhang <zj...@apache.org>
Committed: Thu Dec 18 10:23:25 2014 +0800

----------------------------------------------------------------------
 CHANGES.txt                                     |  1 +
 .../tez/dag/app/dag/impl/TaskAttemptImpl.java   |  2 +-
 .../tez/dag/app/dag/impl/TestTaskAttempt.java   | 95 +++++++++++++-------
 3 files changed, 64 insertions(+), 34 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/tez/blob/32b43765/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index 4cb6054..ec3f8f5 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -6,6 +6,7 @@ Release 0.6.0: Unreleased
 INCOMPATIBLE CHANGES
 
 ALL CHANGES:
+  TEZ-1769. ContainerCompletedWhileRunningTransition should inherit from TerminatedWhileRunningTransition
   TEZ-1849. Fix tez-ui war file licensing.
   TEZ-1840. Document TezTaskOutput.
   TEZ-1576. Class level comment in {{MiniTezCluster}} ends abruptly.

http://git-wip-us.apache.org/repos/asf/tez/blob/32b43765/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/TaskAttemptImpl.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/TaskAttemptImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/TaskAttemptImpl.java
index 1e6ed22..1c8fb8d 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/TaskAttemptImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/TaskAttemptImpl.java
@@ -1325,7 +1325,7 @@ public class TaskAttemptImpl implements TaskAttempt,
   }
 
   protected static class ContainerCompletedWhileRunningTransition extends
-      TerminatedBeforeRunningTransition {
+      TerminatedWhileRunningTransition {
     public ContainerCompletedWhileRunningTransition() {
       super(FAILED_HELPER);
     }

http://git-wip-us.apache.org/repos/asf/tez/blob/32b43765/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestTaskAttempt.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestTaskAttempt.java b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestTaskAttempt.java
index 07e54fe..4a0e7b9 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestTaskAttempt.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestTaskAttempt.java
@@ -312,7 +312,6 @@ public class TestTaskAttempt {
     TezDAGID dagID = TezDAGID.getInstance(appId, 1);
     TezVertexID vertexID = TezVertexID.getInstance(dagID, 1);
     TezTaskID taskID = TezTaskID.getInstance(vertexID, 1);
-    TezTaskAttemptID taskAttemptID = TezTaskAttemptID.getInstance(taskID, 0);
 
     MockEventHandler eventHandler = spy(new MockEventHandler());
     TaskAttemptListener taListener = mock(TaskAttemptListener.class);
@@ -343,11 +342,12 @@ public class TestTaskAttempt {
     doReturn(new ClusterInfo()).when(appCtx).getClusterInfo();
     doReturn(containers).when(appCtx).getAllContainers();
 
+    TaskHeartbeatHandler mockHeartbeatHandler = mock(TaskHeartbeatHandler.class);
     TaskAttemptImpl taImpl = new MockTaskAttemptImpl(taskID, 1, eventHandler,
         taListener, taskConf, new SystemClock(),
-        mock(TaskHeartbeatHandler.class), appCtx, locationHint, false,
+        mockHeartbeatHandler, appCtx, locationHint, false,
         resource, createFakeContainerContext(), false);
-
+    TezTaskAttemptID taskAttemptID = taImpl.getID();
     ArgumentCaptor<Event> arg = ArgumentCaptor.forClass(Event.class);
 
     taImpl.handle(new TaskAttemptEventSchedule(taskAttemptID, 0, 0));
@@ -356,6 +356,7 @@ public class TestTaskAttempt {
         null));
     assertEquals("Task attempt is not in the RUNNING state", taImpl.getState(),
         TaskAttemptState.RUNNING);
+    verify(mockHeartbeatHandler).register(taskAttemptID);
 
     int expectedEventsAtRunning = 3;
     verify(eventHandler, times(expectedEventsAtRunning)).handle(arg.capture());
@@ -365,7 +366,7 @@ public class TestTaskAttempt {
     assertFalse(
         "InternalError occurred trying to handle TA_CONTAINER_TERMINATING",
         eventHandler.internalError);
-
+    verify(mockHeartbeatHandler).unregister(taskAttemptID);
     assertEquals("Task attempt is not in the  FAILED state", taImpl.getState(),
         TaskAttemptState.FAILED);
 
@@ -389,6 +390,8 @@ public class TestTaskAttempt {
 
     taImpl.handle(new TaskAttemptEventContainerTerminated(taskAttemptID,
         "Terminated", TaskAttemptTerminationCause.CONTAINER_EXITED));
+    // verify unregister is not invoked again
+    verify(mockHeartbeatHandler, times(1)).unregister(taskAttemptID);
     int expectedEventAfterTerminated = expectedEvenstAfterTerminating + 0;
     arg = ArgumentCaptor.forClass(Event.class);
     verify(eventHandler, times(expectedEventAfterTerminated)).handle(arg.capture());
@@ -410,7 +413,6 @@ public class TestTaskAttempt {
     TezDAGID dagID = TezDAGID.getInstance(appId, 1);
     TezVertexID vertexID = TezVertexID.getInstance(dagID, 1);
     TezTaskID taskID = TezTaskID.getInstance(vertexID, 1);
-    TezTaskAttemptID taskAttemptID = TezTaskAttemptID.getInstance(taskID, 0);
 
     MockEventHandler eventHandler = new MockEventHandler();
     TaskAttemptListener taListener = mock(TaskAttemptListener.class);
@@ -441,23 +443,26 @@ public class TestTaskAttempt {
     doReturn(new ClusterInfo()).when(appCtx).getClusterInfo();
     doReturn(containers).when(appCtx).getAllContainers();
 
+    TaskHeartbeatHandler mockHeartbeatHandler = mock(TaskHeartbeatHandler.class);
     TaskAttemptImpl taImpl = new MockTaskAttemptImpl(taskID, 1, eventHandler,
         taListener, taskConf, new SystemClock(),
-        mock(TaskHeartbeatHandler.class), appCtx, locationHint, false,
+        mockHeartbeatHandler, appCtx, locationHint, false,
         resource, createFakeContainerContext(), false);
-
+    TezTaskAttemptID taskAttemptID = taImpl.getID();
     taImpl.handle(new TaskAttemptEventSchedule(taskAttemptID, 0, 0));
     // At state STARTING.
     taImpl.handle(new TaskAttemptEventStartedRemotely(taskAttemptID, contId,
         null));
     assertEquals("Task attempt is not in running state", taImpl.getState(),
         TaskAttemptState.RUNNING);
+    verify(mockHeartbeatHandler).register(taskAttemptID);
+
     taImpl.handle(new TaskAttemptEventContainerTerminated(taskAttemptID, "Terminated",
         TaskAttemptTerminationCause.CONTAINER_EXITED));
     assertFalse(
         "InternalError occurred trying to handle TA_CONTAINER_TERMINATED",
         eventHandler.internalError);
-
+    verify(mockHeartbeatHandler).unregister(taskAttemptID);
     assertEquals("Terminated", taImpl.getDiagnostics().get(0));
     assertEquals(TaskAttemptTerminationCause.CONTAINER_EXITED, taImpl.getTerminationCause());
     // TODO Ensure TA_TERMINATING after this is ingored.
@@ -473,7 +478,6 @@ public class TestTaskAttempt {
     TezDAGID dagID = TezDAGID.getInstance(appId, 1);
     TezVertexID vertexID = TezVertexID.getInstance(dagID, 1);
     TezTaskID taskID = TezTaskID.getInstance(vertexID, 1);
-    TezTaskAttemptID taskAttemptID = TezTaskAttemptID.getInstance(taskID, 0);
 
     MockEventHandler eventHandler = spy(new MockEventHandler());
     TaskAttemptListener taListener = mock(TaskAttemptListener.class);
@@ -504,11 +508,12 @@ public class TestTaskAttempt {
     doReturn(new ClusterInfo()).when(appCtx).getClusterInfo();
     doReturn(containers).when(appCtx).getAllContainers();
 
+    TaskHeartbeatHandler mockHeartbeatHandler = mock(TaskHeartbeatHandler.class);
     TaskAttemptImpl taImpl = new MockTaskAttemptImpl(taskID, 1, eventHandler,
         taListener, taskConf, new SystemClock(),
-        mock(TaskHeartbeatHandler.class), appCtx, locationHint, false,
+        mockHeartbeatHandler, appCtx, locationHint, false,
         resource, createFakeContainerContext(), false);
-
+    TezTaskAttemptID taskAttemptID = taImpl.getID();
     ArgumentCaptor<Event> arg = ArgumentCaptor.forClass(Event.class);
 
     taImpl.handle(new TaskAttemptEventSchedule(taskAttemptID, 0, 0));
@@ -517,6 +522,7 @@ public class TestTaskAttempt {
         null));
     assertEquals("Task attempt is not in the RUNNING state", taImpl.getState(),
         TaskAttemptState.RUNNING);
+    verify(mockHeartbeatHandler).register(taskAttemptID);
 
     int expectedEventsAtRunning = 3;
     verify(eventHandler, times(expectedEventsAtRunning)).handle(arg.capture());
@@ -525,7 +531,7 @@ public class TestTaskAttempt {
 
     assertEquals("Task attempt is not in the  SUCCEEDED state", taImpl.getState(),
         TaskAttemptState.SUCCEEDED);
-
+    verify(mockHeartbeatHandler).unregister(taskAttemptID);
     assertEquals(0, taImpl.getDiagnostics().size());
 
     int expectedEvenstAfterTerminating = expectedEventsAtRunning + 3;
@@ -544,6 +550,8 @@ public class TestTaskAttempt {
 
     taImpl.handle(new TaskAttemptEventContainerTerminated(taskAttemptID,
         "Terminated", TaskAttemptTerminationCause.CONTAINER_EXITED));
+    // verify unregister is not invoked again
+    verify(mockHeartbeatHandler, times(1)).unregister(taskAttemptID);
     int expectedEventAfterTerminated = expectedEvenstAfterTerminating + 0;
     arg = ArgumentCaptor.forClass(Event.class);
     verify(eventHandler, times(expectedEventAfterTerminated)).handle(arg.capture());
@@ -562,7 +570,6 @@ public class TestTaskAttempt {
     TezDAGID dagID = TezDAGID.getInstance(appId, 1);
     TezVertexID vertexID = TezVertexID.getInstance(dagID, 1);
     TezTaskID taskID = TezTaskID.getInstance(vertexID, 1);
-    TezTaskAttemptID taskAttemptID = TezTaskAttemptID.getInstance(taskID, 0);
 
     MockEventHandler eventHandler = spy(new MockEventHandler());
     TaskAttemptListener taListener = mock(TaskAttemptListener.class);
@@ -594,11 +601,12 @@ public class TestTaskAttempt {
     doReturn(new ClusterInfo()).when(appCtx).getClusterInfo();
     doReturn(containers).when(appCtx).getAllContainers();
 
+    TaskHeartbeatHandler mockHeartbeatHandler = mock(TaskHeartbeatHandler.class);
     TaskAttemptImpl taImpl = new MockTaskAttemptImpl(taskID, 1, eventHandler,
         taListener, taskConf, new SystemClock(),
-        mock(TaskHeartbeatHandler.class), appCtx, locationHint, false,
+        mockHeartbeatHandler, appCtx, locationHint, false,
         resource, createFakeContainerContext(), false);
-
+    TezTaskAttemptID taskAttemptID = taImpl.getID();
     ArgumentCaptor<Event> arg = ArgumentCaptor.forClass(Event.class);
 
     taImpl.handle(new TaskAttemptEventSchedule(taskAttemptID, 0, 0));
@@ -607,6 +615,7 @@ public class TestTaskAttempt {
         null));
     assertEquals("Task attempt is not in the RUNNING state", taImpl.getState(),
         TaskAttemptState.RUNNING);
+    verify(mockHeartbeatHandler).register(taskAttemptID);
 
     int expectedEventsAtRunning = 4;
     verify(eventHandler, times(expectedEventsAtRunning)).handle(arg.capture());
@@ -621,14 +630,15 @@ public class TestTaskAttempt {
 
     assertEquals("Task attempt is not in the  FAIL_IN_PROGRESS state", taImpl.getInternalState(),
         TaskAttemptStateInternal.FAIL_IN_PROGRESS);
-
+    verify(mockHeartbeatHandler).unregister(taskAttemptID);
     assertEquals(1, taImpl.getDiagnostics().size());
     assertEquals("0", taImpl.getDiagnostics().get(0));
     assertEquals(TaskAttemptTerminationCause.APPLICATION_ERROR, taImpl.getTerminationCause());
-    
+
     taImpl.handle(new TaskAttemptEventContainerTerminated(taskAttemptID, "1",
         TaskAttemptTerminationCause.CONTAINER_EXITED));
-
+    // verify unregister is not invoked again
+    verify(mockHeartbeatHandler, times(1)).unregister(taskAttemptID);
     assertEquals(2, taImpl.getDiagnostics().size());
     assertEquals("1", taImpl.getDiagnostics().get(1));
     // err cause does not change
@@ -662,7 +672,6 @@ public class TestTaskAttempt {
     TezDAGID dagID = TezDAGID.getInstance(appId, 1);
     TezVertexID vertexID = TezVertexID.getInstance(dagID, 1);
     TezTaskID taskID = TezTaskID.getInstance(vertexID, 1);
-    TezTaskAttemptID taskAttemptID = TezTaskAttemptID.getInstance(taskID, 0);
 
     MockEventHandler eventHandler = spy(new MockEventHandler());
     TaskAttemptListener taListener = mock(TaskAttemptListener.class);
@@ -694,11 +703,12 @@ public class TestTaskAttempt {
     doReturn(new ClusterInfo()).when(appCtx).getClusterInfo();
     doReturn(containers).when(appCtx).getAllContainers();
 
+    TaskHeartbeatHandler mockHeartbeatHandler = mock(TaskHeartbeatHandler.class);
     TaskAttemptImpl taImpl = new MockTaskAttemptImpl(taskID, 1, eventHandler,
         taListener, taskConf, new SystemClock(),
-        mock(TaskHeartbeatHandler.class), appCtx, locationHint, false,
+        mockHeartbeatHandler, appCtx, locationHint, false,
         resource, createFakeContainerContext(), false);
-
+    TezTaskAttemptID taskAttemptID = taImpl.getID();
     ArgumentCaptor<Event> arg = ArgumentCaptor.forClass(Event.class);
 
     taImpl.handle(new TaskAttemptEventSchedule(taskAttemptID, 0, 0));
@@ -707,6 +717,7 @@ public class TestTaskAttempt {
         null));
     assertEquals("Task attempt is not in the RUNNING state", taImpl.getState(),
         TaskAttemptState.RUNNING);
+    verify(mockHeartbeatHandler).register(taskAttemptID);
 
     int expectedEventsAtRunning = 4;
     verify(eventHandler, times(expectedEventsAtRunning)).handle(arg.capture());
@@ -720,7 +731,7 @@ public class TestTaskAttempt {
 
     assertEquals("Task attempt is not in the  SUCCEEDED state", taImpl.getState(),
         TaskAttemptState.SUCCEEDED);
-
+    verify(mockHeartbeatHandler).unregister(taskAttemptID);
     assertEquals(0, taImpl.getDiagnostics().size());
 
     int expectedEvenstAfterTerminating = expectedEventsAtRunning + 5;
@@ -753,7 +764,6 @@ public class TestTaskAttempt {
     TezDAGID dagID = TezDAGID.getInstance(appId, 1);
     TezVertexID vertexID = TezVertexID.getInstance(dagID, 1);
     TezTaskID taskID = TezTaskID.getInstance(vertexID, 1);
-    TezTaskAttemptID taskAttemptID = TezTaskAttemptID.getInstance(taskID, 0);
 
     MockEventHandler eventHandler = spy(new MockEventHandler());
     TaskAttemptListener taListener = mock(TaskAttemptListener.class);
@@ -784,10 +794,12 @@ public class TestTaskAttempt {
     doReturn(new ClusterInfo()).when(appCtx).getClusterInfo();
     doReturn(containers).when(appCtx).getAllContainers();
 
+    TaskHeartbeatHandler mockHeartbeatHandler = mock(TaskHeartbeatHandler.class);
     TaskAttemptImpl taImpl = new MockTaskAttemptImpl(taskID, 1, eventHandler,
         taListener, taskConf, new SystemClock(),
-        mock(TaskHeartbeatHandler.class), appCtx, locationHint, false,
+        mockHeartbeatHandler, appCtx, locationHint, false,
         resource, createFakeContainerContext(), false);
+    TezTaskAttemptID taskAttemptID = taImpl.getID();
 
     ArgumentCaptor<Event> arg = ArgumentCaptor.forClass(Event.class);
 
@@ -797,6 +809,7 @@ public class TestTaskAttempt {
         null));
     assertEquals("Task attempt is not in the RUNNING state", taImpl.getState(),
         TaskAttemptState.RUNNING);
+    verify(mockHeartbeatHandler).register(taskAttemptID);
 
     int expectedEventsAtRunning = 3;
     verify(eventHandler, times(expectedEventsAtRunning)).handle(arg.capture());
@@ -805,6 +818,7 @@ public class TestTaskAttempt {
 
     assertEquals("Task attempt is not in the  SUCCEEDED state", taImpl.getState(),
         TaskAttemptState.SUCCEEDED);
+    verify(mockHeartbeatHandler).unregister(taskAttemptID);
 
     assertEquals(0, taImpl.getDiagnostics().size());
 
@@ -824,6 +838,8 @@ public class TestTaskAttempt {
 
     taImpl.handle(new TaskAttemptEvent(taskAttemptID,
         TaskAttemptEventType.TA_CONTAINER_TERMINATED_BY_SYSTEM));
+    // verify unregister is not invoked again
+    verify(mockHeartbeatHandler, times(1)).unregister(taskAttemptID);
     int expectedEventAfterTerminated = expectedEventsAfterTerminating + 0;
     arg = ArgumentCaptor.forClass(Event.class);
     verify(eventHandler, times(expectedEventAfterTerminated)).handle(arg.capture());
@@ -843,7 +859,6 @@ public class TestTaskAttempt {
     TezDAGID dagID = TezDAGID.getInstance(appId, 1);
     TezVertexID vertexID = TezVertexID.getInstance(dagID, 1);
     TezTaskID taskID = TezTaskID.getInstance(vertexID, 1);
-    TezTaskAttemptID taskAttemptID = TezTaskAttemptID.getInstance(taskID, 0);
 
     MockEventHandler eventHandler = spy(new MockEventHandler());
     TaskAttemptListener taListener = mock(TaskAttemptListener.class);
@@ -874,10 +889,12 @@ public class TestTaskAttempt {
     doReturn(new ClusterInfo()).when(appCtx).getClusterInfo();
     doReturn(containers).when(appCtx).getAllContainers();
 
+    TaskHeartbeatHandler mockHeartbeatHandler = mock(TaskHeartbeatHandler.class);
     MockTaskAttemptImpl taImpl = new MockTaskAttemptImpl(taskID, 1, eventHandler,
         taListener, taskConf, new SystemClock(),
-        mock(TaskHeartbeatHandler.class), appCtx, locationHint, false,
+        mockHeartbeatHandler, appCtx, locationHint, false,
         resource, createFakeContainerContext(), false);
+    TezTaskAttemptID taskAttemptID = taImpl.getID();
 
     ArgumentCaptor<Event> arg = ArgumentCaptor.forClass(Event.class);
 
@@ -887,6 +904,7 @@ public class TestTaskAttempt {
         null));
     assertEquals("Task attempt is not in the RUNNING state", TaskAttemptState.RUNNING,
         taImpl.getState());
+    verify(mockHeartbeatHandler).register(taskAttemptID);
 
     int expectedEventsAtRunning = 3;
     verify(eventHandler, times(expectedEventsAtRunning)).handle(arg.capture());
@@ -895,7 +913,7 @@ public class TestTaskAttempt {
 
     assertEquals("Task attempt is not in the  SUCCEEDED state", TaskAttemptState.SUCCEEDED,
         taImpl.getState());
-
+    verify(mockHeartbeatHandler).unregister(taskAttemptID);
     assertEquals(0, taImpl.getDiagnostics().size());
 
     int expectedEvenstAfterTerminating = expectedEventsAtRunning + 3;
@@ -918,6 +936,8 @@ public class TestTaskAttempt {
     // Verify in KILLED state
     assertEquals("Task attempt is not in the  KILLED state", TaskAttemptState.KILLED,
         taImpl.getState());
+    // verify unregister is not invoked again
+    verify(mockHeartbeatHandler, times(1)).unregister(taskAttemptID);
     assertEquals(true, taImpl.inputFailedReported);
     // Verify one event to the Task informing it about FAILURE. No events to scheduler. Counter event.
     int expectedEventsNodeFailure = expectedEvenstAfterTerminating + 2;
@@ -942,7 +962,6 @@ public class TestTaskAttempt {
     TezDAGID dagID = TezDAGID.getInstance(appId, 1);
     TezVertexID vertexID = TezVertexID.getInstance(dagID, 1);
     TezTaskID taskID = TezTaskID.getInstance(vertexID, 1);
-    TezTaskAttemptID taskAttemptID = TezTaskAttemptID.getInstance(taskID, 0);
 
     MockEventHandler eventHandler = spy(new MockEventHandler());
     TaskAttemptListener taListener = mock(TaskAttemptListener.class);
@@ -973,10 +992,12 @@ public class TestTaskAttempt {
     doReturn(new ClusterInfo()).when(appCtx).getClusterInfo();
     doReturn(containers).when(appCtx).getAllContainers();
 
+    TaskHeartbeatHandler mockHeartbeatHandler = mock(TaskHeartbeatHandler.class);
     TaskAttemptImpl taImpl = new MockTaskAttemptImpl(taskID, 1, eventHandler,
         taListener, taskConf, new SystemClock(),
-        mock(TaskHeartbeatHandler.class), appCtx, locationHint, false,
+        mockHeartbeatHandler, appCtx, locationHint, false,
         resource, createFakeContainerContext(), true);
+    TezTaskAttemptID taskAttemptID = taImpl.getID();
 
     ArgumentCaptor<Event> arg = ArgumentCaptor.forClass(Event.class);
 
@@ -986,6 +1007,7 @@ public class TestTaskAttempt {
         null));
     assertEquals("Task attempt is not in the RUNNING state", TaskAttemptState.RUNNING,
         taImpl.getState());
+    verify(mockHeartbeatHandler).register(taskAttemptID);
 
     int expectedEventsAtRunning = 3;
     verify(eventHandler, times(expectedEventsAtRunning)).handle(arg.capture());
@@ -994,7 +1016,7 @@ public class TestTaskAttempt {
 
     assertEquals("Task attempt is not in the  SUCCEEDED state", TaskAttemptState.SUCCEEDED,
         taImpl.getState());
-
+    verify(mockHeartbeatHandler).unregister(taskAttemptID);
     assertEquals(0, taImpl.getDiagnostics().size());
 
     int expectedEvenstAfterTerminating = expectedEventsAtRunning + 3;
@@ -1023,6 +1045,8 @@ public class TestTaskAttempt {
     // Verify still in SUCCEEDED state
     assertEquals("Task attempt is not in the  SUCCEEDED state", TaskAttemptState.SUCCEEDED,
         taImpl.getState());
+    // verify unregister is not invoked again
+    verify(mockHeartbeatHandler, times(1)).unregister(taskAttemptID);
     // error cause remains as default value
     assertEquals(TaskAttemptTerminationCause.UNKNOWN_ERROR, taImpl.getTerminationCause());
   }
@@ -1037,7 +1061,6 @@ public class TestTaskAttempt {
     TezDAGID dagID = TezDAGID.getInstance(appId, 1);
     TezVertexID vertexID = TezVertexID.getInstance(dagID, 1);
     TezTaskID taskID = TezTaskID.getInstance(vertexID, 1);
-    TezTaskAttemptID taskAttemptID = TezTaskAttemptID.getInstance(taskID, 0);
 
     MockEventHandler mockEh = new MockEventHandler();
     MockEventHandler eventHandler = spy(mockEh);
@@ -1069,19 +1092,23 @@ public class TestTaskAttempt {
     doReturn(new ClusterInfo()).when(appCtx).getClusterInfo();
     doReturn(containers).when(appCtx).getAllContainers();
 
+    TaskHeartbeatHandler mockHeartbeatHandler = mock(TaskHeartbeatHandler.class);
     MockTaskAttemptImpl taImpl = new MockTaskAttemptImpl(taskID, 1, eventHandler,
         taListener, taskConf, new SystemClock(),
-        mock(TaskHeartbeatHandler.class), appCtx, locationHint, false,
+        mockHeartbeatHandler, appCtx, locationHint, false,
         resource, createFakeContainerContext(), false);
+    TezTaskAttemptID taskAttemptID = taImpl.getID();
 
     taImpl.handle(new TaskAttemptEventSchedule(taskAttemptID, 0, 0));
     // At state STARTING.
     taImpl.handle(new TaskAttemptEventStartedRemotely(taskAttemptID, contId,
         null));
+    verify(mockHeartbeatHandler).register(taskAttemptID);
     taImpl.handle(new TaskAttemptEvent(taskAttemptID,
         TaskAttemptEventType.TA_DONE));
     assertEquals("Task attempt is not in succeeded state", taImpl.getState(),
         TaskAttemptState.SUCCEEDED);
+    verify(mockHeartbeatHandler).unregister(taskAttemptID);
 
     int expectedEventsTillSucceeded = 6;
     ArgumentCaptor<Event> arg = ArgumentCaptor.forClass(Event.class);
@@ -1098,7 +1125,7 @@ public class TestTaskAttempt {
     // failure threshold not met. state is SUCCEEDED
     assertEquals("Task attempt is not in succeeded state", taImpl.getState(),
         TaskAttemptState.SUCCEEDED);
-    
+
     // sending same error again doesnt change anything
     taImpl.handle(new TaskAttemptEventOutputFailed(taskAttemptID, tzEvent, 4));
     assertEquals("Task attempt is not in succeeded state", taImpl.getState(),
@@ -1114,6 +1141,8 @@ public class TestTaskAttempt {
     assertEquals("Task attempt is not in FAILED state", taImpl.getState(),
         TaskAttemptState.FAILED);
     assertEquals(TaskAttemptTerminationCause.OUTPUT_LOST, taImpl.getTerminationCause());
+    // verify unregister is not invoked again
+    verify(mockHeartbeatHandler, times(1)).unregister(taskAttemptID);
 
     assertEquals(true, taImpl.inputFailedReported);
     int expectedEventsAfterFetchFailure = expectedEventsTillSucceeded + 2;