You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@samza.apache.org by ni...@apache.org on 2016/03/08 01:44:06 UTC

samza git commit: SAMZA-867: Fix job restart/shutdown in the event of node outage

Repository: samza
Updated Branches:
  refs/heads/master d6051086f -> bfba03b7b


SAMZA-867: Fix job restart/shutdown in the event of node outage


Project: http://git-wip-us.apache.org/repos/asf/samza/repo
Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/bfba03b7
Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/bfba03b7
Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/bfba03b7

Branch: refs/heads/master
Commit: bfba03b7bea731ffe7db1ed7e2a6441369e2cd9a
Parents: d605108
Author: Jacob Maes <ja...@gmail.com>
Authored: Mon Mar 7 16:27:18 2016 -0800
Committer: Yi Pan (Data Infrastructure) <ni...@gmail.com>
Committed: Mon Mar 7 16:30:03 2016 -0800

----------------------------------------------------------------------
 .../org/apache/samza/config/TaskConfig.scala    |   2 +
 .../job/yarn/AbstractContainerAllocator.java    |  84 ++++++++++-
 .../samza/job/yarn/ContainerAllocator.java      |  21 +--
 .../samza/job/yarn/ContainerRequestState.java   |  13 ++
 .../apache/samza/job/yarn/ContainerUtil.java    | 144 ++++++++++--------
 .../job/yarn/HostAwareContainerAllocator.java   |  32 +---
 .../apache/samza/job/yarn/SamzaAppState.java    |  32 ++++
 .../job/yarn/SamzaContainerLaunchException.java |  45 ++++++
 .../apache/samza/job/yarn/SamzaTaskManager.java | 147 ++++++++++---------
 .../samza/job/yarn/TestContainerAllocator.java  |  63 +++++++-
 .../yarn/TestHostAwareContainerAllocator.java   |  63 +++++++-
 .../samza/job/yarn/TestSamzaTaskManager.java    |  30 ++--
 .../job/yarn/util/MockContainerAllocator.java   |  10 ++
 .../job/yarn/util/MockContainerListener.java    |  24 ++-
 .../yarn/util/MockContainerRequestState.java    |  26 ++++
 .../samza/job/yarn/util/MockContainerUtil.java  |  18 ++-
 .../samza/job/yarn/util/TestAMRMClientImpl.java |   6 +
 17 files changed, 554 insertions(+), 206 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/samza/blob/bfba03b7/samza-core/src/main/scala/org/apache/samza/config/TaskConfig.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/config/TaskConfig.scala b/samza-core/src/main/scala/org/apache/samza/config/TaskConfig.scala
index 51e9e99..6ff9aac 100644
--- a/samza-core/src/main/scala/org/apache/samza/config/TaskConfig.scala
+++ b/samza-core/src/main/scala/org/apache/samza/config/TaskConfig.scala
@@ -94,6 +94,8 @@ class TaskConfig(config: Config) extends ScalaMapConfig(config) with Logging {
 
   def getCommandClass = getOption(TaskConfig.COMMAND_BUILDER)
 
+  def getCommandClass(defaultValue: String) = getOrDefault(TaskConfig.COMMAND_BUILDER, defaultValue)
+
   def getCheckpointManagerFactory() = getOption(TaskConfig.CHECKPOINT_MANAGER_FACTORY)
 
   def getMessageChooserClass = getOption(TaskConfig.MESSAGE_CHOOSER_CLASS_NAME)

http://git-wip-us.apache.org/repos/asf/samza/blob/bfba03b7/samza-yarn/src/main/java/org/apache/samza/job/yarn/AbstractContainerAllocator.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/main/java/org/apache/samza/job/yarn/AbstractContainerAllocator.java b/samza-yarn/src/main/java/org/apache/samza/job/yarn/AbstractContainerAllocator.java
index 2e192ee..b4789e6 100644
--- a/samza-yarn/src/main/java/org/apache/samza/job/yarn/AbstractContainerAllocator.java
+++ b/samza-yarn/src/main/java/org/apache/samza/job/yarn/AbstractContainerAllocator.java
@@ -18,9 +18,11 @@
  */
 package org.apache.samza.job.yarn;
 
+import java.util.List;
 import org.apache.hadoop.yarn.api.records.Container;
 import org.apache.hadoop.yarn.client.api.AMRMClient;
 import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
+import org.apache.samza.SamzaException;
 import org.apache.samza.config.YarnConfig;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -51,10 +53,10 @@ public abstract class AbstractContainerAllocator implements Runnable {
   protected final int containerMaxCpuCore;
 
   // containerRequestState indicate the state of all unfulfilled container requests and allocated containers
-  protected final ContainerRequestState containerRequestState;
+  private final ContainerRequestState containerRequestState;
 
   // state that controls the lifecycle of the allocator thread
-  protected AtomicBoolean isRunning = new AtomicBoolean(true);
+  private AtomicBoolean isRunning = new AtomicBoolean(true);
 
   public AbstractContainerAllocator(AMRMClientAsync<AMRMClient.ContainerRequest> amClient,
                             ContainerUtil containerUtil,
@@ -79,6 +81,10 @@ public abstract class AbstractContainerAllocator implements Runnable {
     while(isRunning.get()) {
       try {
         assignContainerRequests();
+
+        // Release extra containers and update the entire system's state
+        containerRequestState.releaseExtraContainers();
+
         Thread.sleep(ALLOCATOR_SLEEP_TIME);
       } catch (InterruptedException e) {
         log.info("Got InterruptedException in AllocatorThread.", e);
@@ -95,6 +101,41 @@ public abstract class AbstractContainerAllocator implements Runnable {
   protected abstract void assignContainerRequests();
 
   /**
+   * Updates the request state and runs the container on the specified host. Assumes a container
+   * is available on the preferred host, so the caller must verify that before invoking this method.
+   *
+   * @param request             the {@link SamzaContainerRequest} which is being handled.
+   * @param preferredHost       the preferred host on which the container should be run or
+   *                            {@link ContainerRequestState#ANY_HOST} if there is no host preference.
+   */
+  protected void runContainer(SamzaContainerRequest request, String preferredHost) {
+    // Get the available container
+    Container container = peekAllocatedContainer(preferredHost);
+    if (container == null)
+      throw new SamzaException("Expected container was unavailable on host " + preferredHost);
+
+    // Update state
+    containerRequestState.updateStateAfterAssignment(request, preferredHost, container);
+    int expectedContainerId = request.expectedContainerId;
+
+    // Cancel request and run container
+    log.info("Found available containers on {}. Assigning request for container_id {} with "
+            + "timestamp {} to container {}",
+        new Object[]{preferredHost, String.valueOf(expectedContainerId), request.getRequestTimestamp(), container.getId()});
+    try {
+      if (preferredHost.equals(ANY_HOST)) {
+        containerUtil.runContainer(expectedContainerId, container);
+      } else {
+        containerUtil.runMatchedContainer(expectedContainerId, container);
+      }
+    } catch (SamzaContainerLaunchException e) {
+      log.warn(String.format("Got exception while starting container %s. Requesting a new container on any host", container), e);
+      containerRequestState.releaseUnstartableContainer(container);
+      requestContainer(expectedContainerId, ContainerAllocator.ANY_HOST);
+    }
+  }
+
+  /**
    * Called during initial request for containers
    *
    * @param containerToHostMappings Map of containerId to its last seen host (locality).
@@ -131,6 +172,22 @@ public abstract class AbstractContainerAllocator implements Runnable {
   }
 
   /**
+   * @return {@code true} if there is a pending request, {@code false} otherwise.
+   */
+  protected boolean hasPendingRequest() {
+    return peekPendingRequest() != null;
+  }
+
+  /**
+   * Retrieves, but does not remove, the next pending request in the queue.
+   *
+   * @return  the pending request or {@code null} if there is no pending request.
+   */
+  protected SamzaContainerRequest peekPendingRequest() {
+    return containerRequestState.getRequestsQueue().peek();
+  }
+
+  /**
    * Method that adds allocated container to a synchronized buffer of allocated containers list
    * See allocatedContainers in {@link org.apache.samza.job.yarn.ContainerRequestState}
    *
@@ -140,6 +197,29 @@ public abstract class AbstractContainerAllocator implements Runnable {
     containerRequestState.addContainer(container);
   }
 
+  /**
+   * @param host  the host for which a container is needed.
+   * @return      {@code true} if there is a container allocated for the specified host, {@code false} otherwise.
+   */
+  protected boolean hasAllocatedContainer(String host) {
+    return peekAllocatedContainer(host) != null;
+  }
+
+  /**
+   * Retrieves, but does not remove, the first allocated container on the specified host.
+   *
+   * @param host  the host for which a container is needed.
+   * @return      the first {@link Container} allocated for the specified host or {@code null} if there isn't one.
+   */
+  protected Container peekAllocatedContainer(String host) {
+    List<Container> allocatedContainers = containerRequestState.getContainersOnAHost(host);
+    if (allocatedContainers == null || allocatedContainers.isEmpty()) {
+      return null;
+    }
+
+    return allocatedContainers.get(0);
+  }
+
   public final void setIsRunning(boolean state) {
     isRunning.set(state);
   }

http://git-wip-us.apache.org/repos/asf/samza/blob/bfba03b7/samza-yarn/src/main/java/org/apache/samza/job/yarn/ContainerAllocator.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/main/java/org/apache/samza/job/yarn/ContainerAllocator.java b/samza-yarn/src/main/java/org/apache/samza/job/yarn/ContainerAllocator.java
index 31fcc57..24ac410 100644
--- a/samza-yarn/src/main/java/org/apache/samza/job/yarn/ContainerAllocator.java
+++ b/samza-yarn/src/main/java/org/apache/samza/job/yarn/ContainerAllocator.java
@@ -23,8 +23,6 @@ import org.apache.hadoop.yarn.api.records.Container;
 import org.apache.hadoop.yarn.client.api.AMRMClient;
 import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
 import org.apache.samza.config.YarnConfig;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 /**
  * This is the default allocator thread that will be used by SamzaTaskManager.
@@ -33,8 +31,6 @@ import org.slf4j.LoggerFactory;
  * If there aren't enough containers, it waits by sleeping for {@code ALLOCATOR_SLEEP_TIME} milliseconds.
  */
 public class ContainerAllocator extends AbstractContainerAllocator {
-  private static final Logger log = LoggerFactory.getLogger(ContainerAllocator.class);
-
   public ContainerAllocator(AMRMClientAsync<AMRMClient.ContainerRequest> amClient,
                             ContainerUtil containerUtil,
                             YarnConfig yarnConfig) {
@@ -50,20 +46,9 @@ public class ContainerAllocator extends AbstractContainerAllocator {
    * */
   @Override
   public void assignContainerRequests() {
-    List<Container> allocatedContainers = containerRequestState.getContainersOnAHost(ANY_HOST);
-    while (!containerRequestState.getRequestsQueue().isEmpty() && allocatedContainers != null && allocatedContainers.size() > 0) {
-      SamzaContainerRequest request = containerRequestState.getRequestsQueue().peek();
-      Container container = allocatedContainers.get(0);
-
-      // Update state
-      containerRequestState.updateStateAfterAssignment(request, ANY_HOST, container);
-
-      // Cancel request and run container
-      log.info("Running {} on {}", request.expectedContainerId, container.getId());
-      containerUtil.runContainer(request.expectedContainerId, container);
+    while (hasPendingRequest() && hasAllocatedContainer(ANY_HOST)) {
+      SamzaContainerRequest request = peekPendingRequest();
+      runContainer(request, ANY_HOST);
     }
-
-    // If requestQueue is empty, all extra containers in the buffer should be released.
-    containerRequestState.releaseExtraContainers();
   }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/bfba03b7/samza-yarn/src/main/java/org/apache/samza/job/yarn/ContainerRequestState.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/main/java/org/apache/samza/job/yarn/ContainerRequestState.java b/samza-yarn/src/main/java/org/apache/samza/job/yarn/ContainerRequestState.java
index 54db5e5..3e3f48c 100644
--- a/samza-yarn/src/main/java/org/apache/samza/job/yarn/ContainerRequestState.java
+++ b/samza-yarn/src/main/java/org/apache/samza/job/yarn/ContainerRequestState.java
@@ -239,6 +239,19 @@ public class ContainerRequestState {
   }
 
   /**
+   * Releases a container that was allocated and assigned but could not be started.
+   * e.g. because of a ConnectException while trying to communicate with the NM.
+   * This method assumes the specified container and associated request have already
+   * been removed from their respective queues.
+   *
+   * @param container the {@link Container} to release.
+   */
+  public void releaseUnstartableContainer(Container container) {
+    log.info("Releasing unstartable container {}", container.getId());
+    amClient.releaseAssignedContainer(container.getId());
+  }
+
+  /**
    * Clears all the state variables
    * Performed when there are no more unfulfilled requests
    * This is not synchronized because it is private.

http://git-wip-us.apache.org/repos/asf/samza/blob/bfba03b7/samza-yarn/src/main/java/org/apache/samza/job/yarn/ContainerUtil.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/main/java/org/apache/samza/job/yarn/ContainerUtil.java b/samza-yarn/src/main/java/org/apache/samza/job/yarn/ContainerUtil.java
index 91fae98..6580b9a 100644
--- a/samza-yarn/src/main/java/org/apache/samza/job/yarn/ContainerUtil.java
+++ b/samza-yarn/src/main/java/org/apache/samza/job/yarn/ContainerUtil.java
@@ -18,6 +18,13 @@
  */
 package org.apache.samza.job.yarn;
 
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.DataOutputBuffer;
@@ -27,7 +34,12 @@ import org.apache.hadoop.security.token.Token;
 import org.apache.hadoop.security.token.TokenIdentifier;
 import org.apache.hadoop.yarn.api.ApplicationConstants;
 import org.apache.hadoop.yarn.api.protocolrecords.StartContainerRequest;
-import org.apache.hadoop.yarn.api.records.*;
+import org.apache.hadoop.yarn.api.records.Container;
+import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
+import org.apache.hadoop.yarn.api.records.LocalResource;
+import org.apache.hadoop.yarn.api.records.LocalResourceType;
+import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
+import org.apache.hadoop.yarn.api.records.URL;
 import org.apache.hadoop.yarn.client.api.NMClient;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
 import org.apache.hadoop.yarn.exceptions.YarnException;
@@ -44,10 +56,6 @@ import org.apache.samza.util.Util;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.*;
-
 public class ContainerUtil {
   private static final Logger log = LoggerFactory.getLogger(ContainerUtil.class);
 
@@ -81,74 +89,47 @@ public class ContainerUtil {
     state.containerRequests.incrementAndGet();
   }
 
-  public void runMatchedContainer(int samzaContainerId, Container container) {
+  public void runMatchedContainer(int samzaContainerId, Container container) throws SamzaContainerLaunchException {
     state.matchedContainerRequests.incrementAndGet();
     runContainer(samzaContainerId, container);
   }
 
-  public void runContainer(int samzaContainerId, Container container) {
+  public void runContainer(int samzaContainerId, Container container) throws SamzaContainerLaunchException {
     String containerIdStr = ConverterUtils.toString(container.getId());
     log.info("Got available container ID ({}) for container: {}", samzaContainerId, container);
 
-    String cmdBuilderClassName;
-    if (taskConfig.getCommandClass().isDefined()) {
-      cmdBuilderClassName = taskConfig.getCommandClass().get();
-    } else {
-      cmdBuilderClassName = ShellCommandBuilder.class.getName();
+    CommandBuilder cmdBuilder = getCommandBuilder(samzaContainerId);
+    String command = cmdBuilder.buildCommand();
+    log.info("Container ID {} using command {}", samzaContainerId, command);
+
+    Map<String, String> env = getEscapedEnvironmentVariablesMap(cmdBuilder);
+    printContainerEnvironmentVariables(samzaContainerId, env);
+
+    Path path = new Path(yarnConfig.getPackagePath());
+    log.info("Starting container ID {} using package path {}", samzaContainerId, path);
+
+    startContainer(path, container, env,
+        getFormattedCommand(ApplicationConstants.LOG_DIR_EXPANSION_VAR, command, ApplicationConstants.STDOUT,
+            ApplicationConstants.STDERR));
+
+    if (state.neededContainers.decrementAndGet() == 0) {
+      state.jobHealthy.set(true);
     }
-      CommandBuilder cmdBuilder = (CommandBuilder) Util.getObj(cmdBuilderClassName);
-      cmdBuilder
-          .setConfig(config)
-          .setId(samzaContainerId)
-          .setUrl(state.coordinatorUrl);
-
-      String command = cmdBuilder.buildCommand();
-      log.info("Container ID {} using command {}", samzaContainerId, command);
-
-      log.info("Container ID {} using environment variables: ", samzaContainerId);
-      Map<String, String> env = new HashMap<String, String>();
-      for (Map.Entry<String, String> entry: cmdBuilder.buildEnvironment().entrySet()) {
-        String escapedValue = Util.envVarEscape(entry.getValue());
-        env.put(entry.getKey(), escapedValue);
-        log.info("{}={} ", entry.getKey(), escapedValue);
-      }
+    state.runningContainers.put(samzaContainerId, new YarnContainer(container));
 
-      Path path = new Path(yarnConfig.getPackagePath());
-      log.info("Starting container ID {} using package path {}", samzaContainerId, path);
-
-      startContainer(
-          path,
-          container,
-          env,
-          getFormattedCommand(
-              ApplicationConstants.LOG_DIR_EXPANSION_VAR,
-              command,
-              ApplicationConstants.STDOUT,
-              ApplicationConstants.STDERR)
-      );
-
-      if (state.neededContainers.decrementAndGet() == 0) {
-        state.jobHealthy.set(true);
-      }
-      state.runningContainers.put(samzaContainerId, new YarnContainer(container));
-
-      log.info("Claimed container ID {} for container {} on node {} (http://{}/node/containerlogs/{}).",
-          new Object[]{
-              samzaContainerId,
-              containerIdStr,
-              container.getNodeId().getHost(),
-              container.getNodeHttpAddress(),
-              containerIdStr}
-      );
-
-      log.info("Started container ID {}", samzaContainerId);
+    log.info("Claimed container ID {} for container {} on node {} (http://{}/node/containerlogs/{}).",
+        new Object[]{samzaContainerId, containerIdStr, container
+            .getNodeId().getHost(), container.getNodeHttpAddress(), containerIdStr});
+
+    log.info("Started container ID {}", samzaContainerId);
   }
 
   protected void startContainer(Path packagePath,
                                 Container container,
                                 Map<String, String> env,
-                                final String cmd) {
-    log.info("starting container {} {} {} {}",
+                                final String cmd)
+      throws SamzaContainerLaunchException {
+    log.info("Starting container {} {} {} {}",
         new Object[]{packagePath, container, env, cmd});
 
     // set the local package so that the containers and app master are provisioned with it
@@ -205,10 +186,10 @@ public class ContainerUtil {
       nmClient.startContainer(container, context);
     } catch (YarnException ye) {
       log.error("Received YarnException when starting container: " + container.getId(), ye);
-      throw new SamzaException("Received YarnException when starting container: " + container.getId());
+      throw new SamzaContainerLaunchException("Received YarnException when starting container: " + container.getId(), ye);
     } catch (IOException ioe) {
       log.error("Received IOException when starting container: " + container.getId(), ioe);
-      throw new SamzaException("Received IOException when starting container: " + container.getId());
+      throw new SamzaContainerLaunchException("Received IOException when starting container: " + container.getId(), ioe);
     }
   }
 
@@ -219,4 +200,45 @@ public class ContainerUtil {
     return "export SAMZA_LOG_DIR=" + logDirExpansionVar + " && ln -sfn " + logDirExpansionVar +
         " logs && exec ./__package/" + command + " 1>logs/" + stdOut + " 2>logs/" + stdErr;
   }
+
+  /**
+   * Instantiates and initializes the configured {@link CommandBuilder} class.
+   *
+   * @param samzaContainerId  the Samza container Id for which the container start command will be built.
+   * @return                  the command builder, which is initialized and ready to build the command.
+   */
+  private CommandBuilder getCommandBuilder(int samzaContainerId) {
+    String cmdBuilderClassName = taskConfig.getCommandClass(ShellCommandBuilder.class.getName());
+    CommandBuilder cmdBuilder = (CommandBuilder) Util.getObj(cmdBuilderClassName);
+    cmdBuilder.setConfig(config).setId(samzaContainerId).setUrl(state.coordinatorUrl);
+    return cmdBuilder;
+  }
+
+  /**
+   * Gets the environment variables from the specified {@link CommandBuilder} and escapes certain characters.
+   *
+   * @param cmdBuilder        the command builder containing the environment variables.
+   * @return                  the map containing the escaped environment variables.
+   */
+  private Map<String, String> getEscapedEnvironmentVariablesMap(CommandBuilder cmdBuilder) {
+    Map<String, String> env = new HashMap<String, String>();
+    for (Map.Entry<String, String> entry : cmdBuilder.buildEnvironment().entrySet()) {
+      String escapedValue = Util.envVarEscape(entry.getValue());
+      env.put(entry.getKey(), escapedValue);
+    }
+
+    return env;
+  }
+
+  /**
+   * @param samzaContainerId  the Samza container Id for logging purposes.
+   * @param env               the Map of environment variables to their respective values.
+   */
+  private void printContainerEnvironmentVariables(int samzaContainerId, Map<String, String> env) {
+    StringBuilder sb = new StringBuilder();
+    for (Map.Entry<String, String> entry : env.entrySet()) {
+      sb.append(String.format("\n%s=%s", entry.getKey(), entry.getValue()));
+    }
+    log.info("Container ID {} using environment variables: {}", samzaContainerId, sb.toString());
+  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/bfba03b7/samza-yarn/src/main/java/org/apache/samza/job/yarn/HostAwareContainerAllocator.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/main/java/org/apache/samza/job/yarn/HostAwareContainerAllocator.java b/samza-yarn/src/main/java/org/apache/samza/job/yarn/HostAwareContainerAllocator.java
index 8e1db77..9797196 100644
--- a/samza-yarn/src/main/java/org/apache/samza/job/yarn/HostAwareContainerAllocator.java
+++ b/samza-yarn/src/main/java/org/apache/samza/job/yarn/HostAwareContainerAllocator.java
@@ -18,8 +18,6 @@
  */
 package org.apache.samza.job.yarn;
 
-import java.util.List;
-import org.apache.hadoop.yarn.api.records.Container;
 import org.apache.hadoop.yarn.client.api.AMRMClient;
 import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
 import org.apache.samza.config.YarnConfig;
@@ -56,48 +54,32 @@ public class HostAwareContainerAllocator extends AbstractContainerAllocator {
    */
   @Override
   public void assignContainerRequests() {
-    while (!containerRequestState.getRequestsQueue().isEmpty()) {
-      SamzaContainerRequest request = containerRequestState.getRequestsQueue().peek();
+    while (hasPendingRequest()) {
+      SamzaContainerRequest request = peekPendingRequest();
       String preferredHost = request.getPreferredHost();
       int expectedContainerId = request.getExpectedContainerId();
 
       log.info("Handling request for container id {} on preferred host {}", expectedContainerId, preferredHost);
 
-      List<Container> allocatedContainers = containerRequestState.getContainersOnAHost(preferredHost);
-      if (allocatedContainers != null && allocatedContainers.size() > 0) {
+      if (hasAllocatedContainer(preferredHost)) {
         // Found allocated container at preferredHost
-        Container container = allocatedContainers.get(0);
-
-        containerRequestState.updateStateAfterAssignment(request, preferredHost, container);
-
-        log.info("Running {} on {}", expectedContainerId, container.getId());
-        containerUtil.runMatchedContainer(expectedContainerId, container);
+        runContainer(request, preferredHost);
       } else {
         // No allocated container on preferredHost
         log.info("Did not find any allocated containers on preferred host {} for running container id {}",
             preferredHost, expectedContainerId);
         boolean expired = requestExpired(request);
-        allocatedContainers = containerRequestState.getContainersOnAHost(ANY_HOST);
-        if (!expired || allocatedContainers == null || allocatedContainers.size() == 0) {
+        if (!expired || !hasAllocatedContainer(ANY_HOST)) { // fall back to ANY_HOST only once expired
           log.info("Either the request timestamp {} is greater than container request timeout {}ms or we couldn't "
                   + "find any free allocated containers in the buffer. Breaking out of loop.",
               request.getRequestTimestamp(), CONTAINER_REQUEST_TIMEOUT);
           break;
         } else {
-          if (allocatedContainers.size() > 0) {
-            Container container = allocatedContainers.get(0);
-            log.info("Found available containers on ANY_HOST. Assigning request for container_id {} with "
-                    + "timestamp {} to container {}",
-                new Object[]{String.valueOf(expectedContainerId), request.getRequestTimestamp(), container.getId()});
-            containerRequestState.updateStateAfterAssignment(request, ANY_HOST, container);
-            log.info("Running {} on {}", expectedContainerId, container.getId());
-            containerUtil.runContainer(expectedContainerId, container);
-          }
+          runContainer(request, ANY_HOST);
         }
       }
     }
-    // Release extra containers and update the entire system's state
-    containerRequestState.releaseExtraContainers();
+
   }
 
   private boolean requestExpired(SamzaContainerRequest request) {

http://git-wip-us.apache.org/repos/asf/samza/blob/bfba03b7/samza-yarn/src/main/java/org/apache/samza/job/yarn/SamzaAppState.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/main/java/org/apache/samza/job/yarn/SamzaAppState.java b/samza-yarn/src/main/java/org/apache/samza/job/yarn/SamzaAppState.java
index bc5b606..77280ba 100644
--- a/samza-yarn/src/main/java/org/apache/samza/job/yarn/SamzaAppState.java
+++ b/samza-yarn/src/main/java/org/apache/samza/job/yarn/SamzaAppState.java
@@ -19,6 +19,7 @@
 
 package org.apache.samza.job.yarn;
 
+import java.util.Map;
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.hadoop.yarn.api.records.ContainerId;
 import org.apache.hadoop.yarn.api.records.FinalApplicationStatus;
@@ -35,6 +36,11 @@ import java.util.concurrent.atomic.AtomicInteger;
 
 public class SamzaAppState {
   /**
+   * Represents an invalid or unknown Samza container ID.
+   */
+  private static final int UNUSED_CONTAINER_ID = -1;
+
+  /**
    * Job Coordinator is started in the AM and follows the {@link org.apache.samza.job.yarn.SamzaAppMasterService}
    * lifecycle. It helps querying JobModel related info in {@link org.apache.samza.webapp.ApplicationMasterRestServlet}
    * and locality information when host-affinity is enabled in {@link org.apache.samza.job.yarn.SamzaTaskManager}
@@ -177,4 +183,30 @@ public class SamzaAppState {
     this.appAttemptId = amContainerId.getApplicationAttemptId();
 
   }
+
+  /**
+   * Returns the Samza container ID if the specified YARN container ID corresponds to a running container.
+   *
+   * @param yarnContainerId the YARN container ID.
+   * @return                the Samza container ID if it is running,
+   *                        otherwise {@link SamzaAppState#UNUSED_CONTAINER_ID}.
+   */
+  public int getRunningSamzaContainerId(ContainerId yarnContainerId) {
+    int containerId = UNUSED_CONTAINER_ID;
+    for(Map.Entry<Integer, YarnContainer> entry: runningContainers.entrySet()) {
+      if(entry.getValue().id().equals(yarnContainerId)) {
+        containerId = entry.getKey();
+        break;
+      }
+    }
+    return containerId;
+  }
+
+  /**
+   * @param samzaContainerId  the Samza container ID to validate.
+   * @return                  {@code true} if the ID is valid, {@code false} otherwise
+   */
+  public static boolean isValidContainerId(int samzaContainerId) {
+    return samzaContainerId != UNUSED_CONTAINER_ID;
+  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/bfba03b7/samza-yarn/src/main/java/org/apache/samza/job/yarn/SamzaContainerLaunchException.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/main/java/org/apache/samza/job/yarn/SamzaContainerLaunchException.java b/samza-yarn/src/main/java/org/apache/samza/job/yarn/SamzaContainerLaunchException.java
new file mode 100644
index 0000000..4ba936c
--- /dev/null
+++ b/samza-yarn/src/main/java/org/apache/samza/job/yarn/SamzaContainerLaunchException.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.job.yarn;
+
+/**
+ * {@code SamzaContainerLaunchException} indicates an {@link Exception} during container launch.
+ * It can wrap another type of {@link Throwable} or {@link Exception}. Ultimately, any exception thrown
+ * during container launch should be of this type so it can be handled explicitly.
+ */
+public class SamzaContainerLaunchException extends Exception {
+
+  private static final long serialVersionUID = -3957939806997013992L;
+
+  public SamzaContainerLaunchException() {
+    super();
+  }
+
+  public SamzaContainerLaunchException(String s, Throwable t) {
+    super(s, t);
+  }
+
+  public SamzaContainerLaunchException(String s) {
+    super(s);
+  }
+
+  public SamzaContainerLaunchException(Throwable t) {
+    super(t);
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/bfba03b7/samza-yarn/src/main/java/org/apache/samza/job/yarn/SamzaTaskManager.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/main/java/org/apache/samza/job/yarn/SamzaTaskManager.java b/samza-yarn/src/main/java/org/apache/samza/job/yarn/SamzaTaskManager.java
index a3562a1..caee6e6 100644
--- a/samza-yarn/src/main/java/org/apache/samza/job/yarn/SamzaTaskManager.java
+++ b/samza-yarn/src/main/java/org/apache/samza/job/yarn/SamzaTaskManager.java
@@ -18,6 +18,8 @@
  */
 package org.apache.samza.job.yarn;
 
+import java.util.HashMap;
+import java.util.Map;
 import org.apache.hadoop.yarn.api.records.Container;
 import org.apache.hadoop.yarn.api.records.ContainerExitStatus;
 import org.apache.hadoop.yarn.api.records.ContainerStatus;
@@ -32,8 +34,6 @@ import org.apache.samza.config.YarnConfig;
 import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import java.util.HashMap;
-import java.util.Map;
 
 /**
  * Samza's application master is mostly interested in requesting containers to
@@ -145,13 +145,7 @@ class SamzaTaskManager implements YarnAppMasterListener {
   @Override
   public void onContainerCompleted(ContainerStatus containerStatus) {
     String containerIdStr = ConverterUtils.toString(containerStatus.getContainerId());
-    int containerId = -1;
-    for(Map.Entry<Integer, YarnContainer> entry: state.runningContainers.entrySet()) {
-      if(entry.getValue().id().equals(containerStatus.getContainerId())) {
-        containerId = entry.getKey();
-        break;
-      }
-    }
+    int containerId = state.getRunningSamzaContainerId(containerStatus.getContainerId());
     state.runningContainers.remove(containerId);
 
     int exitStatus = containerStatus.getExitStatus();
@@ -161,7 +155,7 @@ class SamzaTaskManager implements YarnAppMasterListener {
 
         state.completedContainers.incrementAndGet();
 
-        if (containerId != -1) {
+        if (SamzaAppState.isValidContainerId(containerId)) {
           state.finishedContainers.add(containerId);
           containerFailures.remove(containerId);
         }
@@ -182,12 +176,11 @@ class SamzaTaskManager implements YarnAppMasterListener {
             containerIdStr);
 
         state.releasedContainers.incrementAndGet();
-
         // If this container was assigned some partitions (a containerId), then
         // clean up, and request a new container for the tasks. This only
         // should happen if the container was 'lost' due to node failure, not
         // if the AM released the container.
-        if (containerId != -1) {
+        if (SamzaAppState.isValidContainerId(containerId)) {
           log.info("Released container {} was assigned task group ID {}. Requesting a new container for the task group.", containerIdStr, containerId);
 
           state.neededContainers.incrementAndGet();
@@ -196,6 +189,7 @@ class SamzaTaskManager implements YarnAppMasterListener {
           // request a container on new host
           containerAllocator.requestContainer(containerId, ContainerAllocator.ANY_HOST);
         }
+
         break;
 
       default:
@@ -207,67 +201,16 @@ class SamzaTaskManager implements YarnAppMasterListener {
         state.failedContainersStatus.put(containerIdStr, containerStatus);
         state.jobHealthy.set(false);
 
-        if(containerId != -1) {
+        if(SamzaAppState.isValidContainerId(containerId)) {
           state.neededContainers.incrementAndGet();
-          // Find out previously running container location
-          String lastSeenOn = state.jobCoordinator.jobModel().getContainerToHostValue(containerId, SetContainerHostMapping.HOST_KEY);
-          if (!hostAffinityEnabled || lastSeenOn == null) {
-            lastSeenOn = ContainerAllocator.ANY_HOST;
-          }
-          // A container failed for an unknown reason. Let's check to see if
-          // we need to shutdown the whole app master if too many container
-          // failures have happened. The rules for failing are that the
-          // failure count for a task group id must be > the configured retry
-          // count, and the last failure (the one prior to this one) must have
-          // happened less than retry window ms ago. If retry count is set to
-          // 0, the app master will fail on any container failure. If the
-          // retry count is set to a number < 0, a container failure will
-          // never trigger an app master failure.
-          int retryCount = yarnConfig.getContainerRetryCount();
-          int retryWindowMs = yarnConfig.getContainerRetryWindowMs();
-
-          if (retryCount == 0) {
-            log.error("Container ID {} ({}) failed, and retry count is set to 0, so shutting down the application master, and marking the job as failed.", containerId, containerIdStr);
-
-            tooManyFailedContainers = true;
-          } else if (retryCount > 0) {
-            int currentFailCount;
-            long lastFailureTime;
-            if(containerFailures.containsKey(containerId)) {
-              ContainerFailure failure = containerFailures.get(containerId);
-              currentFailCount = failure.getCount() + 1;
-              lastFailureTime = failure.getLastFailure();
-              } else {
-              currentFailCount = 1;
-              lastFailureTime = 0L;
-            }
-            if (currentFailCount >= retryCount) {
-              long lastFailureMsDiff = System.currentTimeMillis() - lastFailureTime;
-
-              if (lastFailureMsDiff < retryWindowMs) {
-                log.error("Container ID " + containerId + "(" + containerIdStr + ") has failed " + currentFailCount +
-                    " times, with last failure " + lastFailureMsDiff + "ms ago. This is greater than retry count of " +
-                    retryCount + " and window of " + retryWindowMs + "ms , so shutting down the application master, and marking the job as failed.");
-
-                // We have too many failures, and we're within the window
-                // boundary, so reset shut down the app master.
-                tooManyFailedContainers = true;
-                state.status = FinalApplicationStatus.FAILED;
-              } else {
-                log.info("Resetting fail count for container ID {} back to 1, since last container failure ({}) for " +
-                    "this container ID was outside the bounds of the retry window.", containerId, containerIdStr);
-
-                // Reset counter back to 1, since the last failure for this
-                // container happened outside the window boundary.
-                containerFailures.put(containerId, new ContainerFailure(1, System.currentTimeMillis()));
-              }
-            } else {
-              log.info("Current fail count for container ID {} is {}.", containerId, currentFailCount);
-              containerFailures.put(containerId, new ContainerFailure(currentFailCount, System.currentTimeMillis()));
-            }
-          }
+          recordContainerFailCount(containerIdStr, containerId);
 
           if (!tooManyFailedContainers) {
+            // Find out previously running container location
+            String lastSeenOn = state.jobCoordinator.jobModel().getContainerToHostValue(containerId, SetContainerHostMapping.HOST_KEY);
+            if (!hostAffinityEnabled || lastSeenOn == null) {
+              lastSeenOn = ContainerAllocator.ANY_HOST;
+            }
             // Request a new container
             containerAllocator.requestContainer(containerId, lastSeenOn);
           }
@@ -275,4 +218,68 @@ class SamzaTaskManager implements YarnAppMasterListener {
 
     }
   }
+
+  /**
+   * Increments the failure count, logs the failure, and records the last failure time for the specified container.
+   * Also, updates the global flag indicating whether too many failures have occurred and returns that flag.
+   *
+   * @param containerIdStr  the YARN container Id for logging purposes.
+   * @param containerId     the Samza container/group Id that failed.
+   * @return                true if too many failures have occurred and the job should be shut down as FAILED.
+   */
+  private boolean recordContainerFailCount(String containerIdStr, int containerId) {
+    // A container failed for an unknown reason. Let's check to see if
+    // we need to shutdown the whole app master if too many container
+    // failures have happened. The rules for failing are that the
+    // failure count for a task group id must be > the configured retry
+    // count (1 initial failure + retryCount retries), and the last failure
+    // (the one prior to this one) must have happened less than retry
+    // window ms ago. If retry count is set to 0, the app master will fail
+    // on any container failure. If the retry count is set to a number < 0,
+    // a container failure will never trigger an app master failure.
+    int retryCount = yarnConfig.getContainerRetryCount();
+    int retryWindowMs = yarnConfig.getContainerRetryWindowMs();
+
+    if (retryCount == 0) {
+      log.error("Container ID {} ({}) failed, and retry count is set to 0, so shutting down the application master, and marking the job as failed.", containerId, containerIdStr);
+      tooManyFailedContainers = true;
+      state.status = FinalApplicationStatus.FAILED;
+    } else if (retryCount > 0) {
+      int currentFailCount;
+      long lastFailureTime;
+      if(containerFailures.containsKey(containerId)) {
+        ContainerFailure failure = containerFailures.get(containerId);
+        currentFailCount = failure.getCount() + 1;
+        lastFailureTime = failure.getLastFailure();
+      } else {
+        currentFailCount = 1;
+        lastFailureTime = 0L;
+      }
+      if (currentFailCount > retryCount) {
+        long lastFailureMsDiff = System.currentTimeMillis() - lastFailureTime;
+
+        if (lastFailureMsDiff < retryWindowMs) {
+          log.error("Container ID " + containerId + "(" + containerIdStr + ") has failed " + currentFailCount +
+              " times, with last failure " + lastFailureMsDiff + "ms ago. This is greater than retry count of " +
+              retryCount + " and window of " + retryWindowMs + "ms , so shutting down the application master, and marking the job as failed.");
+
+          // We have too many failures, and we're within the window
+          // boundary, so shut down the app master and fail the job.
+          tooManyFailedContainers = true;
+          state.status = FinalApplicationStatus.FAILED;
+        } else {
+          log.info("Resetting fail count for container ID {} back to 1, since last container failure ({}) for " +
+              "this container ID was outside the bounds of the retry window.", containerId, containerIdStr);
+
+          // Reset counter back to 1, since the last failure for this
+          // container happened outside the window boundary.
+          containerFailures.put(containerId, new ContainerFailure(1, System.currentTimeMillis()));
+        }
+      } else {
+        log.info("Current fail count for container ID {} is {}.", containerId, currentFailCount);
+        containerFailures.put(containerId, new ContainerFailure(currentFailCount, System.currentTimeMillis()));
+      }
+    }
+    return tooManyFailedContainers;
+  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/bfba03b7/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestContainerAllocator.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestContainerAllocator.java b/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestContainerAllocator.java
index e2b45d7..2b1bdab 100644
--- a/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestContainerAllocator.java
+++ b/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestContainerAllocator.java
@@ -19,6 +19,7 @@
 
 package org.apache.samza.job.yarn;
 
+import java.io.IOException;
 import java.lang.reflect.Field;
 import java.net.URL;
 import java.util.ArrayList;
@@ -39,6 +40,7 @@ import org.apache.samza.job.model.JobModel;
 import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.job.yarn.util.MockContainerListener;
 import org.apache.samza.job.yarn.util.MockContainerRequestState;
+import org.apache.samza.job.yarn.util.MockContainerUtil;
 import org.apache.samza.job.yarn.util.MockHttpServer;
 import org.apache.samza.job.yarn.util.TestAMRMClientImpl;
 import org.apache.samza.job.yarn.util.TestUtil;
@@ -61,6 +63,7 @@ public class TestContainerAllocator {
   private TestAMRMClientImpl testAMRMClient;
   private MockContainerRequestState requestState;
   private ContainerAllocator containerAllocator;
+  private ContainerUtil containerUtil;
   private Thread allocatorThread;
 
   private Config config = new MapConfig(new HashMap<String, String>() {
@@ -105,9 +108,10 @@ public class TestContainerAllocator {
     state.coordinatorUrl = new URL("http://localhost:7778/");
 
     requestState = new MockContainerRequestState(amRmClientAsync, false);
+    containerUtil = TestUtil.getContainerUtil(config, state);
     containerAllocator = new ContainerAllocator(
         amRmClientAsync,
-        TestUtil.getContainerUtil(config, state),
+        containerUtil,
         new YarnConfig(config)
     );
     Field requestStateField = containerAllocator.getClass().getSuperclass().getDeclaredField("containerRequestState");
@@ -195,6 +199,59 @@ public class TestContainerAllocator {
   }
 
   /**
+   * If the container fails to start (e.g. because it fails to connect to the NM on a host that
+   * is down), the allocator should release it and re-request the container on ANY_HOST.
+   */
+  @Test
+  public void testRerequestOnAnyHostIfContainerStartFails() throws Exception {
+    final Container container = TestUtil.getContainer(ConverterUtils.toContainerId("container_1350670447861_0003_01_000001"), "2", 123);
+    final Container container1 = TestUtil.getContainer(ConverterUtils.toContainerId("container_1350670447861_0003_01_000002"), "1", 123);
+
+    ((MockContainerUtil) containerUtil).containerStartException = new IOException("You shall not... connect to the NM!");
+
+    // Expect 2 containers added, 1 released (the one that fails to start) and 2 assignments (original + retry)
+    MockContainerListener listener = new MockContainerListener(2, 1, 2, null, new Runnable() {
+      @Override
+      public void run() {
+        // The failed container should be released. The successful one should not.
+        assertNotNull(testAMRMClient.getRelease());
+        assertEquals(1, testAMRMClient.getRelease().size());
+        assertTrue(testAMRMClient.getRelease().contains(container.getId()));
+      }
+    },
+        new Runnable() {
+          @Override
+          public void run() {
+            // The original assignment keeps preferred host "2"; the retry falls back to ANY_HOST
+            assertEquals(2, requestState.assignedRequests.size());
+
+            SamzaContainerRequest request = requestState.assignedRequests.remove();
+            assertEquals(0, request.expectedContainerId);
+            assertEquals("2", request.getPreferredHost());
+
+            request = requestState.assignedRequests.remove();
+            assertEquals(0, request.expectedContainerId);
+            assertEquals("ANY_HOST", request.getPreferredHost());
+
+            // This routine should be called after the retry is assigned, but before it's started.
+            // So there should still be 1 container needed.
+            assertEquals(1, state.neededContainers.get());
+          }
+        }
+    );
+    requestState.registerContainerListener(listener);
+
+    allocatorThread.start();
+
+    // Only request 1 container and we should see 2 assignments in the assertions above (because of the retry)
+    containerAllocator.requestContainer(0, "2");
+    containerAllocator.addContainer(container);
+    containerAllocator.addContainer(container1);
+
+    listener.verify();
+  }
+
+  /**
    * Extra allocated containers that are returned by the RM and unused by the AM should be released.
    * Containers are considered "extra" only when there are no more pending requests to fulfill
    * @throws Exception
@@ -206,7 +263,7 @@ public class TestContainerAllocator {
     final Container container2 = TestUtil.getContainer(ConverterUtils.toContainerId("container_1350670447861_0003_01_000003"), "def", 123);
 
     // Set up our final asserts before starting the allocator thread
-    MockContainerListener listener = new MockContainerListener(3, 2, null, new Runnable() {
+    MockContainerListener listener = new MockContainerListener(3, 2, 0, null, new Runnable() {
       @Override
       public void run() {
         assertNotNull(testAMRMClient.getRelease());
@@ -220,7 +277,7 @@ public class TestContainerAllocator {
         assertNull(requestState.getContainersOnAHost("abc"));
         assertNull(requestState.getContainersOnAHost("def"));
       }
-    });
+    }, null);
     requestState.registerContainerListener(listener);
 
     allocatorThread.start();

http://git-wip-us.apache.org/repos/asf/samza/blob/bfba03b7/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestHostAwareContainerAllocator.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestHostAwareContainerAllocator.java b/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestHostAwareContainerAllocator.java
index 269d824..0c7a09f 100644
--- a/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestHostAwareContainerAllocator.java
+++ b/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestHostAwareContainerAllocator.java
@@ -18,6 +18,7 @@
  */
 package org.apache.samza.job.yarn;
 
+import java.io.IOException;
 import java.lang.reflect.Field;
 import java.net.URL;
 import java.util.ArrayList;
@@ -137,6 +138,59 @@ public class TestHostAwareContainerAllocator {
     allocatorThread.join();
   }
 
+  /**
+   * If the container fails to start (e.g. because it fails to connect to the NM on a host that
+   * is down), the allocator should release it and re-request the container on ANY_HOST.
+   */
+  @Test
+  public void testRerequestOnAnyHostIfContainerStartFails() throws Exception {
+    final Container container = TestUtil.getContainer(ConverterUtils.toContainerId("container_1350670447861_0003_01_000001"), "2", 123);
+    final Container container1 = TestUtil.getContainer(ConverterUtils.toContainerId("container_1350670447861_0003_01_000002"), "1", 123);
+
+    ((MockContainerUtil) containerUtil).containerStartException = new IOException("You shall not... connect to the NM!");
+
+    // Expect 2 containers added, 1 released (the one that fails to start) and 2 assignments (original + retry)
+    MockContainerListener listener = new MockContainerListener(2, 1, 2, null, new Runnable() {
+      @Override
+      public void run() {
+        // The failed container should be released. The successful one should not.
+        assertNotNull(testAMRMClient.getRelease());
+        assertEquals(1, testAMRMClient.getRelease().size());
+        assertTrue(testAMRMClient.getRelease().contains(container.getId()));
+      }
+    },
+        new Runnable() {
+          @Override
+          public void run() {
+            // The original assignment keeps preferred host "2"; the retry falls back to ANY_HOST
+            assertEquals(2, requestState.assignedRequests.size());
+
+            SamzaContainerRequest request = requestState.assignedRequests.remove();
+            assertEquals(0, request.expectedContainerId);
+            assertEquals("2", request.getPreferredHost());
+
+            request = requestState.assignedRequests.remove();
+            assertEquals(0, request.expectedContainerId);
+            assertEquals("ANY_HOST", request.getPreferredHost());
+
+            // This routine should be called after the retry is assigned, but before it's started.
+            // So there should still be 1 container needed.
+            assertEquals(1, state.neededContainers.get());
+          }
+        }
+    );
+    requestState.registerContainerListener(listener);
+
+    // Only request 1 container and we should see 2 assignments in the assertions above (because of the retry)
+    containerAllocator.requestContainer(0, "2");
+    containerAllocator.addContainer(container1);
+    containerAllocator.addContainer(container);
+
+    allocatorThread.start();
+
+    listener.verify();
+  }
+
   @Test
   public void testAllocatorReleasesExtraContainers() throws Exception {
     final Container container = TestUtil.getContainer(ConverterUtils.toContainerId("container_1350670447861_0003_01_000001"), "abc", 123);
@@ -144,7 +198,7 @@ public class TestHostAwareContainerAllocator {
     final Container container2 = TestUtil.getContainer(ConverterUtils.toContainerId("container_1350670447861_0003_01_000003"), "def", 123);
 
     // Set up our final asserts before starting the allocator thread
-    MockContainerListener listener = new MockContainerListener(3, 2, null, new Runnable() {
+    MockContainerListener listener = new MockContainerListener(3, 2, 0, null, new Runnable() {
       @Override
       public void run() {
         assertNotNull(testAMRMClient.getRelease());
@@ -158,7 +212,8 @@ public class TestHostAwareContainerAllocator {
         assertNull(requestState.getContainersOnAHost("abc"));
         assertNull(requestState.getContainersOnAHost("def"));
       }
-    });
+    },
+    null);
     requestState.registerContainerListener(listener);
 
     allocatorThread.start();
@@ -293,7 +348,7 @@ public class TestHostAwareContainerAllocator {
     assertTrue(requestState.getRequestsToCountMap().get("def").get() == 1);
 
     // Set up our final asserts before starting the allocator thread
-    MockContainerListener listener = new MockContainerListener(2, 0, new Runnable() {
+    MockContainerListener listener = new MockContainerListener(2, 0, 0, new Runnable() {
       @Override
       public void run() {
         assertNull(requestState.getContainersOnAHost("xyz"));
@@ -301,7 +356,7 @@ public class TestHostAwareContainerAllocator {
         assertNotNull(requestState.getContainersOnAHost(ANY_HOST));
         assertTrue(requestState.getContainersOnAHost(ANY_HOST).size() == 2);
       }
-    }, null);
+    }, null, null);
     requestState.registerContainerListener(listener);
 
     allocatorThread.start();

http://git-wip-us.apache.org/repos/asf/samza/blob/bfba03b7/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestSamzaTaskManager.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestSamzaTaskManager.java b/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestSamzaTaskManager.java
index 88d9f24..9da1edf 100644
--- a/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestSamzaTaskManager.java
+++ b/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestSamzaTaskManager.java
@@ -281,7 +281,7 @@ public class TestSamzaTaskManager {
     taskManager.onInit();
 
     assertFalse(taskManager.shouldShutdown());
-    assertEquals(1, allocator.containerRequestState.getRequestsQueue().size());
+    assertEquals(1, allocator.getContainerRequestState().getRequestsQueue().size());
 
     Container container = TestUtil.getContainer(ConverterUtils.toContainerId("container_1350670447861_0003_01_000002"), "abc", 123);
     taskManager.onContainerAllocated(container);
@@ -293,8 +293,8 @@ public class TestSamzaTaskManager {
     taskManager.onContainerCompleted(TestUtil.getContainerStatus(container.getId(), 1, "Expecting a failure here"));
 
     // The above failure should trigger a container request
-    assertEquals(1, allocator.containerRequestState.getRequestsQueue().size());
-    assertEquals(ContainerRequestState.ANY_HOST, allocator.containerRequestState.getRequestsQueue().peek().getPreferredHost());
+    assertEquals(1, allocator.getContainerRequestState().getRequestsQueue().size());
+    assertEquals(ContainerRequestState.ANY_HOST, allocator.getContainerRequestState().getRequestsQueue().peek().getPreferredHost());
     assertFalse(taskManager.shouldShutdown());
     assertFalse(state.jobHealthy.get());
     assertEquals(2, testAMRMClient.requests.size());
@@ -311,7 +311,7 @@ public class TestSamzaTaskManager {
     taskManager.onContainerCompleted(TestUtil.getContainerStatus(container.getId(), 1, "Expecting a failure here"));
 
     // The above failure should trigger a job shutdown because our retry count is set to 1
-    assertEquals(0, allocator.containerRequestState.getRequestsQueue().size());
+    assertEquals(0, allocator.getContainerRequestState().getRequestsQueue().size());
     assertEquals(2, testAMRMClient.requests.size());
     assertEquals(0, testAMRMClient.getRelease().size());
     assertFalse(state.jobHealthy.get());
@@ -347,7 +347,7 @@ public class TestSamzaTaskManager {
     taskManager.onInit();
 
     assertFalse(taskManager.shouldShutdown());
-    assertEquals(1, allocator.containerRequestState.getRequestsQueue().size());
+    assertEquals(1, allocator.getContainerRequestState().getRequestsQueue().size());
 
     Container container = TestUtil.getContainer(ConverterUtils.toContainerId("container_1350670447861_0003_01_000002"), "abc", 123);
     taskManager.onContainerAllocated(container);
@@ -359,8 +359,8 @@ public class TestSamzaTaskManager {
     taskManager.onContainerCompleted(TestUtil.getContainerStatus(container.getId(), 1, "Expecting a failure here"));
 
     // The above failure should trigger a container request
-    assertEquals(1, allocator.containerRequestState.getRequestsQueue().size());
-    assertEquals("abc", allocator.containerRequestState.getRequestsQueue().peek().getPreferredHost());
+    assertEquals(1, allocator.getContainerRequestState().getRequestsQueue().size());
+    assertEquals("abc", allocator.getContainerRequestState().getRequestsQueue().peek().getPreferredHost());
     assertFalse(taskManager.shouldShutdown());
     assertFalse(state.jobHealthy.get());
     assertEquals(2, testAMRMClient.requests.size());
@@ -377,7 +377,7 @@ public class TestSamzaTaskManager {
     taskManager.onContainerCompleted(TestUtil.getContainerStatus(container.getId(), 1, "Expecting a failure here"));
 
     // The above failure should trigger a job shutdown because our retry count is set to 1
-    assertEquals(0, allocator.containerRequestState.getRequestsQueue().size());
+    assertEquals(0, allocator.getContainerRequestState().getRequestsQueue().size());
     assertEquals(2, testAMRMClient.requests.size());
     assertEquals(0, testAMRMClient.getRelease().size());
     assertFalse(state.jobHealthy.get());
@@ -415,7 +415,7 @@ public class TestSamzaTaskManager {
     // Start the task manager
     taskManager.onInit();
     assertFalse(taskManager.shouldShutdown());
-    assertEquals(1, allocator.containerRequestState.getRequestsQueue().size());
+    assertEquals(1, allocator.getContainerRequestState().getRequestsQueue().size());
 
     Container container = TestUtil.getContainer(ConverterUtils.toContainerId("container_1350670447861_0003_01_000002"), "abc", 123);
     taskManager.onContainerAllocated(container);
@@ -427,32 +427,32 @@ public class TestSamzaTaskManager {
     taskManager.onContainerCompleted(TestUtil.getContainerStatus(container.getId(), ContainerExitStatus.DISKS_FAILED, "Disk failure"));
 
     // The above failure should trigger a container request
-    assertEquals(1, allocator.containerRequestState.getRequestsQueue().size());
+    assertEquals(1, allocator.getContainerRequestState().getRequestsQueue().size());
     assertFalse(taskManager.shouldShutdown());
     assertFalse(state.jobHealthy.get());
     assertEquals(2, testAMRMClient.requests.size());
     assertEquals(0, testAMRMClient.getRelease().size());
-    assertEquals(ContainerRequestState.ANY_HOST, allocator.containerRequestState.getRequestsQueue().peek().getPreferredHost());
+    assertEquals(ContainerRequestState.ANY_HOST, allocator.getContainerRequestState().getRequestsQueue().peek().getPreferredHost());
 
     // Create container failure - with ContainerExitStatus.PREEMPTED
     taskManager.onContainerCompleted(TestUtil.getContainerStatus(container.getId(), ContainerExitStatus.PREEMPTED, "Task Preempted by RM"));
 
     // The above failure should trigger a container request
-    assertEquals(1, allocator.containerRequestState.getRequestsQueue().size());
+    assertEquals(1, allocator.getContainerRequestState().getRequestsQueue().size());
     assertFalse(taskManager.shouldShutdown());
     assertFalse(state.jobHealthy.get());
-    assertEquals(ContainerRequestState.ANY_HOST, allocator.containerRequestState.getRequestsQueue().peek().getPreferredHost());
+    assertEquals(ContainerRequestState.ANY_HOST, allocator.getContainerRequestState().getRequestsQueue().peek().getPreferredHost());
 
     // Create container failure - with ContainerExitStatus.ABORTED
     taskManager.onContainerCompleted(TestUtil.getContainerStatus(container.getId(), ContainerExitStatus.ABORTED, "Task Aborted by the NM"));
 
     // The above failure should trigger a container request
-    assertEquals(1, allocator.containerRequestState.getRequestsQueue().size());
+    assertEquals(1, allocator.getContainerRequestState().getRequestsQueue().size());
     assertEquals(2, testAMRMClient.requests.size());
     assertEquals(0, testAMRMClient.getRelease().size());
     assertFalse(taskManager.shouldShutdown());
     assertFalse(state.jobHealthy.get());
-    assertEquals(ContainerRequestState.ANY_HOST, allocator.containerRequestState.getRequestsQueue().peek().getPreferredHost());
+    assertEquals(ContainerRequestState.ANY_HOST, allocator.getContainerRequestState().getRequestsQueue().peek().getPreferredHost());
 
     taskManager.onShutdown();
   }

http://git-wip-us.apache.org/repos/asf/samza/blob/bfba03b7/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerAllocator.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerAllocator.java b/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerAllocator.java
index 5fcad82..3290247 100644
--- a/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerAllocator.java
+++ b/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerAllocator.java
@@ -18,10 +18,13 @@
  */
 package org.apache.samza.job.yarn.util;
 
+import java.lang.reflect.Field;
 import org.apache.hadoop.yarn.client.api.AMRMClient;
 import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
 import org.apache.samza.config.YarnConfig;
+import org.apache.samza.job.yarn.AbstractContainerAllocator;
 import org.apache.samza.job.yarn.ContainerAllocator;
+import org.apache.samza.job.yarn.ContainerRequestState;
 import org.apache.samza.job.yarn.ContainerUtil;
 
 import java.util.Map;
@@ -40,4 +43,11 @@ public class MockContainerAllocator extends ContainerAllocator {
     requestedContainers += containerToHostMappings.size();
     super.requestContainers(containerToHostMappings);
   }
+
+  /** Reads the {@code containerRequestState} field declared on AbstractContainerAllocator via reflection. */
+  public ContainerRequestState getContainerRequestState() throws Exception {
+    Field requestStateField = AbstractContainerAllocator.class.getDeclaredField("containerRequestState");
+    requestStateField.setAccessible(true);
+    return (ContainerRequestState) requestStateField.get(this);
+  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/bfba03b7/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerListener.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerListener.java b/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerListener.java
index 8fc0b98..cb82ccc 100644
--- a/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerListener.java
+++ b/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerListener.java
@@ -19,27 +19,33 @@

 package org.apache.samza.job.yarn.util;
 
 import org.apache.hadoop.yarn.api.records.Container;
 
 import static org.junit.Assert.assertTrue;
 
 public class MockContainerListener {
-  private static final int NUM_CONDITIONS = 2;
+  private static final int NUM_CONDITIONS = 3;
   private boolean allContainersAdded = false;
   private boolean allContainersReleased = false;
   private final int numExpectedContainersAdded;
   private final int numExpectedContainersReleased;
+  private final int numExpectedContainersAssigned;
   private final Runnable addContainerAssertions;
   private final Runnable releaseContainerAssertions;
+  private final Runnable assignContainerAssertions;
 
   public MockContainerListener(int numExpectedContainersAdded,
       int numExpectedContainersReleased,
+      int numExpectedContainersAssigned,
       Runnable addContainerAssertions,
-      Runnable releaseContainerAssertions) {
+      Runnable releaseContainerAssertions,
+      Runnable assignContainerAssertions) {
     this.numExpectedContainersAdded = numExpectedContainersAdded;
     this.numExpectedContainersReleased = numExpectedContainersReleased;
+    this.numExpectedContainersAssigned = numExpectedContainersAssigned;
     this.addContainerAssertions = addContainerAssertions;
     this.releaseContainerAssertions = releaseContainerAssertions;
+    this.assignContainerAssertions = assignContainerAssertions;
   }
 
   public synchronized void postAddContainer(Container container, int totalAddedContainers) {
@@ -77,4 +83,24 @@ public class MockContainerListener {
     assertTrue("Not all containers were added.", allContainersAdded);
     assertTrue("Not all containers were released.", allContainersReleased);
   }
+
+  /**
+   * Called after the request state records a container assignment.
+   *
+   * <p>When {@code totalAssignedContainers} reaches the expected count, runs the optional
+   * assignment assertions and wakes any thread waiting on this listener's monitor. The method
+   * must be {@code synchronized}: {@link Object#notifyAll()} throws
+   * {@link IllegalMonitorStateException} if the caller does not own this object's monitor.
+   *
+   * @param totalAssignedContainers number of containers assigned so far
+   */
+  public synchronized void postUpdateRequestStateAfterAssignment(int totalAssignedContainers) {
+    if (totalAssignedContainers == numExpectedContainersAssigned) {
+      if (assignContainerAssertions != null) {
+        assignContainerAssertions.run();
+      }
+
+      this.notifyAll();
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/bfba03b7/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerRequestState.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerRequestState.java b/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerRequestState.java
index e7441e5..879a7d0 100644
--- a/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerRequestState.java
+++ b/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerRequestState.java
@@ -19,23 +19,44 @@
 package org.apache.samza.job.yarn.util;
 
 import java.util.ArrayList;
+import java.util.LinkedList;
 import java.util.List;
+import java.util.Queue;
 import org.apache.hadoop.yarn.api.records.Container;
 import org.apache.hadoop.yarn.client.api.AMRMClient;
 import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
 import org.apache.samza.job.yarn.ContainerRequestState;
+import org.apache.samza.job.yarn.SamzaContainerRequest;
 
 
 public class MockContainerRequestState extends ContainerRequestState {
   private final List<MockContainerListener> _mockContainerListeners = new ArrayList<MockContainerListener>();
   private int numAddedContainers = 0;
   private int numReleasedContainers = 0;
+  private int numAssignedContainers = 0;
+  // Requests passed to updateStateAfterAssignment, in call order (exposed for test inspection).
+  public final Queue<SamzaContainerRequest> assignedRequests = new LinkedList<>();
 
   public MockContainerRequestState(AMRMClientAsync<AMRMClient.ContainerRequest> amClient,
       boolean hostAffinityEnabled) {
     super(amClient, hostAffinityEnabled);
   }
 
+  /**
+   * Delegates to the real state update, then records the assigned request and notifies every
+   * registered {@link MockContainerListener} with the running assignment count.
+   */
+  @Override
+  public synchronized void updateStateAfterAssignment(SamzaContainerRequest request, String assignedHost, Container container) {
+    super.updateStateAfterAssignment(request, assignedHost, container);
+
+    numAssignedContainers++;
+    assignedRequests.add(request);
+
+    for (MockContainerListener listener : _mockContainerListeners) {
+      listener.postUpdateRequestStateAfterAssignment(numAssignedContainers);
+    }
+  }
 
   @Override
   public synchronized void addContainer(Container container) {
@@ -58,6 +79,21 @@ public class MockContainerRequestState extends ContainerRequestState {
     return numAddedContainers;
   }
 
+  /**
+   * Delegates to the real release, then increments the released-container count and notifies
+   * every registered {@link MockContainerListener}. Synchronized, like
+   * {@code updateStateAfterAssignment}, so counter updates and listener iteration are serialized.
+   */
+  @Override
+  public synchronized void releaseUnstartableContainer(Container container) {
+    super.releaseUnstartableContainer(container);
+
+    numReleasedContainers += 1;
+    for (MockContainerListener listener : _mockContainerListeners) {
+      listener.postReleaseContainers(numReleasedContainers);
+    }
+  }
+
   public void registerContainerListener(MockContainerListener listener) {
     _mockContainerListeners.add(listener);
   }

http://git-wip-us.apache.org/repos/asf/samza/blob/bfba03b7/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerUtil.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerUtil.java b/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerUtil.java
index 4426ce6..2f9669f 100644
--- a/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerUtil.java
+++ b/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/MockContainerUtil.java
@@ -18,6 +18,10 @@
  */
 package org.apache.samza.job.yarn.util;
 
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.yarn.api.records.Container;
 import org.apache.hadoop.yarn.client.api.NMClient;
@@ -25,14 +29,11 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration;
 import org.apache.samza.config.Config;
 import org.apache.samza.job.yarn.ContainerUtil;
 import org.apache.samza.job.yarn.SamzaAppState;
+import org.apache.samza.job.yarn.SamzaContainerLaunchException;
 
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.HashMap;
-
 public class MockContainerUtil extends ContainerUtil {
   public final Map<String, List<Container>> runningContainerList = new HashMap<>();
+  public Exception containerStartException = null;
 
   public MockContainerUtil(Config config, SamzaAppState state, YarnConfiguration conf, NMClient nmClient) {
     super(config, state, conf);
@@ -40,7 +41,7 @@ public class MockContainerUtil extends ContainerUtil {
   }
 
   @Override
-  public void runContainer(int samzaContainerId, Container container) {
+  public void runContainer(int samzaContainerId, Container container) throws SamzaContainerLaunchException {
     String hostname = container.getNodeHttpAddress().split(":")[0];
     List<Container> list = runningContainerList.get(hostname);
     if (list == null) {
@@ -55,7 +56,17 @@ public class MockContainerUtil extends ContainerUtil {
   }
 
+  /**
+   * No-op stand-in for the real container launch.
+   *
+   * @throws SamzaContainerLaunchException wrapping {@link #containerStartException} if a test has
+   *     set it, to simulate a launch failure
+   */
   @Override
-  public void startContainer(Path packagePath, Container container, Map<String, String> env, String cmd) {
+  public void startContainer(Path packagePath, Container container, Map<String, String> env, String cmd)
+      throws SamzaContainerLaunchException {
+    if (containerStartException != null) {
+      throw new SamzaContainerLaunchException(containerStartException);
+    }
   }
 
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/bfba03b7/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/TestAMRMClientImpl.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/TestAMRMClientImpl.java b/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/TestAMRMClientImpl.java
index 951e0f9..59226ca 100644
--- a/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/TestAMRMClientImpl.java
+++ b/samza-yarn/src/test/java/org/apache/samza/job/yarn/util/TestAMRMClientImpl.java
@@ -63,6 +63,12 @@ public class TestAMRMClientImpl extends AMRMClientImpl<ContainerRequest> {
   }
 
   @Override
+  public synchronized void releaseAssignedContainer(ContainerId containerId) {
+    pendingRelease.add(containerId); // record the id in both the pendingRelease and release collections
+    release.add(containerId); // NOTE(review): containerId is not null-checked here — confirm callers never pass null
+  }
+
+  @Override
   public void unregisterApplicationMaster(FinalApplicationStatus appStatus, String appMessage, String appTrackingUrl)
       throws YarnException, IOException { }