You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@samza.apache.org by jm...@apache.org on 2018/05/18 21:23:18 UTC

samza git commit: SAMZA-1508: JobRunner should not return success until the job is healthy

Repository: samza
Updated Branches:
  refs/heads/master 171793b69 -> 7a2b4650c


SAMZA-1508: JobRunner should not return success until the job is healthy

Author: Jacob Maes <jm...@apache.org>
Author: Jacob Maes <jm...@linkedin.com>

Reviewers: Prateek Maheshwari <pm...@linkedin.com>

Closes #367 from jmakes/samza-1508


Project: http://git-wip-us.apache.org/repos/asf/samza/repo
Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/7a2b4650
Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/7a2b4650
Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/7a2b4650

Branch: refs/heads/master
Commit: 7a2b4650cd7c87f0475bc06d306cd4ef834377b0
Parents: 171793b
Author: Jacob Maes <jm...@apache.org>
Authored: Fri May 18 14:23:07 2018 -0700
Committer: Jacob Maes <--global>
Committed: Fri May 18 14:23:07 2018 -0700

----------------------------------------------------------------------
 build.gradle                                    |   1 +
 .../scala/org/apache/samza/job/JobRunner.scala  |  38 +--
 .../webapp/ApplicationMasterRestClient.java     | 111 +++++++
 .../apache/samza/job/yarn/ClientHelper.scala    |  54 ++-
 .../org/apache/samza/job/yarn/YarnJob.scala     |  10 +-
 .../webapp/ApplicationMasterRestServlet.scala   |  76 +++--
 .../webapp/TestApplicationMasterRestClient.java | 330 +++++++++++++++++++
 .../samza/job/yarn/TestClientHelper.scala       |  36 +-
 8 files changed, 583 insertions(+), 73 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/samza/blob/7a2b4650/build.gradle
----------------------------------------------------------------------
diff --git a/build.gradle b/build.gradle
index 2f27a03..0b4dae5 100644
--- a/build.gradle
+++ b/build.gradle
@@ -448,6 +448,7 @@ project(":samza-yarn_$scalaVersion") {
     compile "org.scala-lang:scala-compiler:$scalaLibVersion"
     compile "org.codehaus.jackson:jackson-mapper-asl:$jacksonVersion"
     compile "commons-httpclient:commons-httpclient:$commonsHttpClientVersion"
+    compile "org.apache.httpcomponents:httpclient:$httpClientVersion"
     compile("org.apache.hadoop:hadoop-yarn-api:$yarnVersion") {
       exclude module: 'slf4j-log4j12'
     }

http://git-wip-us.apache.org/repos/asf/samza/blob/7a2b4650/samza-core/src/main/scala/org/apache/samza/job/JobRunner.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/job/JobRunner.scala b/samza-core/src/main/scala/org/apache/samza/job/JobRunner.scala
index 917c018..c6e14f2 100644
--- a/samza-core/src/main/scala/org/apache/samza/job/JobRunner.scala
+++ b/samza-core/src/main/scala/org/apache/samza/job/JobRunner.scala
@@ -20,6 +20,8 @@
 package org.apache.samza.job
 
 
+import java.util.concurrent.TimeUnit
+
 import org.apache.samza.SamzaException
 import org.apache.samza.config.Config
 import org.apache.samza.config.JobConfig.Config2Job
@@ -117,23 +119,11 @@ class JobRunner(config: Config) extends Logging {
     coordinatorSystemProducer.stop()
 
     // Create the actual job, and submit it.
-    val job = jobFactory.getJob(config).submit
-
-    info("waiting for job to start")
-
-    // Wait until the job has started, then exit.
-    Option(job.waitForStatus(Running, 500)) match {
-      case Some(appStatus) => {
-        if (Running.equals(appStatus)) {
-          info("job started successfully - " + appStatus)
-        } else {
-          warn("unable to start job successfully. job has status %s" format (appStatus))
-        }
-      }
-      case _ => warn("unable to start job successfully.")
-    }
+    val job = jobFactory.getJob(config)
+
+    job.submit()
 
-    info("exiting")
+    info("Job submitted. Check status to determine when it is running.")
     job
   }
 
@@ -143,21 +133,7 @@ class JobRunner(config: Config) extends Logging {
     // Create the actual job, and kill it.
     val job = jobFactory.getJob(config).kill()
 
-    info("waiting for job to terminate")
-
-    // Wait until the job has terminated, then exit.
-    Option(job.waitForFinish(5000)) match {
-      case Some(appStatus) => {
-        if (SuccessfulFinish.equals(appStatus)) {
-          info("job terminated successfully - " + appStatus)
-        } else {
-          warn("unable to terminate job successfully. job has status %s" format (appStatus))
-        }
-      }
-      case _ => warn("unable to terminate job successfully.")
-    }
-
-    info("exiting")
+    info("Kill command executed. Check status to determine when it is terminated.")
   }
 
   def status(): ApplicationStatus = {

http://git-wip-us.apache.org/repos/asf/samza/blob/7a2b4650/samza-yarn/src/main/java/org/apache/samza/webapp/ApplicationMasterRestClient.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/main/java/org/apache/samza/webapp/ApplicationMasterRestClient.java b/samza-yarn/src/main/java/org/apache/samza/webapp/ApplicationMasterRestClient.java
new file mode 100644
index 0000000..eed16db
--- /dev/null
+++ b/samza-yarn/src/main/java/org/apache/samza/webapp/ApplicationMasterRestClient.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.webapp;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Map;
+import org.apache.http.HttpHost;
+import org.apache.http.HttpResponse;
+import org.apache.http.HttpStatus;
+import org.apache.http.StatusLine;
+import org.apache.http.client.methods.HttpGet;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.util.EntityUtils;
+import org.apache.samza.SamzaException;
+import org.apache.samza.serializers.model.SamzaObjectMapper;
+import org.codehaus.jackson.map.ObjectMapper;
+import org.codehaus.jackson.type.TypeReference;
+
+
+/**
+ * Client for the {@link ApplicationMasterRestServlet}.
+ */
+public class ApplicationMasterRestClient implements Closeable {
+  private final CloseableHttpClient httpClient;
+  private final HttpHost appMasterHost;
+  private final ObjectMapper jsonMapper = SamzaObjectMapper.getObjectMapper();
+
+  public ApplicationMasterRestClient(CloseableHttpClient client, String amHostName, int amRpcPort) {
+    httpClient = client;
+    appMasterHost = new HttpHost(amHostName, amRpcPort);
+  }
+
+  /**
+   * @return  the metrics as a map of groupName to metricName to metricValue.
+   * @throws IOException if there was an error fetching the metrics from the servlet.
+   */
+  public Map<String, Map<String, Object>> getMetrics() throws IOException {
+    String jsonString = getEntityAsJson("/metrics", "metrics");
+    return jsonMapper.readValue(jsonString, new TypeReference<Map<String, Map<String, Object>>>() {});
+  }
+
+  /**
+   * @return  the task context as a map of key to value
+   * @throws IOException if there was an error fetching the task context from the servlet.
+   */
+  public Map<String, Object> getTaskContext() throws IOException {
+    String jsonString = getEntityAsJson("/task-context", "task context");
+    return jsonMapper.readValue(jsonString, new TypeReference<Map<String, Object>>() {});
+  }
+
+  /**
+   * @return  the AM state as a map of key to value
+   * @throws IOException if there was an error fetching the AM state from the servlet.
+   */
+  public Map<String, Object> getAmState() throws IOException {
+    String jsonString = getEntityAsJson("/am", "AM state");
+    return jsonMapper.readValue(jsonString, new TypeReference<Map<String, Object>>() {});
+  }
+
+  /**
+   * @return  the config as a map of key to value
+   * @throws IOException if there was an error fetching the config from the servlet.
+   */
+  public Map<String, Object> getConfig() throws IOException {
+    String jsonString = getEntityAsJson("/config", "config");
+    return jsonMapper.readValue(jsonString, new TypeReference<Map<String, Object>>() {});
+  }
+
+  @Override
+  public void close() throws IOException {
+    httpClient.close();
+  }
+
+  private String getEntityAsJson(String path, String entityName) throws IOException {
+    HttpGet getRequest = new HttpGet(path);
+    HttpResponse httpResponse = httpClient.execute(appMasterHost, getRequest);
+
+    StatusLine status = httpResponse.getStatusLine();
+    if (status.getStatusCode() != HttpStatus.SC_OK) {
+      throw new SamzaException(String.format(
+          "Error retrieving %s from host %s. Response: %s",
+          entityName,
+          appMasterHost.toURI(),
+          status.getReasonPhrase()));
+    }
+
+    return EntityUtils.toString(httpResponse.getEntity());
+  }
+
+  @Override
+  public String toString() {
+    return "AppMasterClient for uri: " + appMasterHost.toURI().toString();
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/7a2b4650/samza-yarn/src/main/scala/org/apache/samza/job/yarn/ClientHelper.scala
----------------------------------------------------------------------
diff --git a/samza-yarn/src/main/scala/org/apache/samza/job/yarn/ClientHelper.scala b/samza-yarn/src/main/scala/org/apache/samza/job/yarn/ClientHelper.scala
index e6c0896..d193ddb 100644
--- a/samza-yarn/src/main/scala/org/apache/samza/job/yarn/ClientHelper.scala
+++ b/samza-yarn/src/main/scala/org/apache/samza/job/yarn/ClientHelper.scala
@@ -19,7 +19,6 @@
 
 package org.apache.samza.job.yarn
 
-
 import org.apache.commons.lang.StringUtils
 import org.apache.hadoop.fs.permission.FsPermission
 import org.apache.samza.config.{Config, JobConfig, YarnConfig}
@@ -57,6 +56,9 @@ import org.apache.samza.util.Logging
 import java.io.IOException
 import java.nio.ByteBuffer
 
+import org.apache.http.impl.client.HttpClientBuilder
+import org.apache.samza.webapp.ApplicationMasterRestClient
+
 object ClientHelper {
   val applicationType = "Samza"
 
@@ -81,6 +83,13 @@ class ClientHelper(conf: Configuration) extends Logging {
     yarnClient
   }
 
+  private[yarn] def createAmClient(applicationReport: ApplicationReport) = {
+    val amHostName = applicationReport.getHost
+    val rpcPort = applicationReport.getRpcPort
+
+    new ApplicationMasterRestClient(HttpClientBuilder.create.build, amHostName, rpcPort)
+  }
+
   var jobContext: JobContext = null
 
   /**
@@ -242,8 +251,8 @@ class ClientHelper(conf: Configuration) extends Logging {
 
     applicationReports
       .asScala
-        .filter(applicationReport => isActiveApplication(applicationReport)
-          && appName.equals(applicationReport.getName))
+        .filter(applicationReport => appName.equals(applicationReport.getName)
+          && isActiveApplication(applicationReport))
         .map(applicationReport => applicationReport.getApplicationId)
         .toList
   }
@@ -253,8 +262,8 @@ class ClientHelper(conf: Configuration) extends Logging {
 
     applicationReports
       .asScala
-      .filter(applicationReport => (!(isActiveApplication(applicationReport))
-        && appName.equals(applicationReport.getName)))
+      .filter(applicationReport => appName.equals(applicationReport.getName)
+        && (!isActiveApplication(applicationReport)))
       .map(applicationReport => applicationReport.getApplicationId)
       .toList
   }
@@ -305,8 +314,39 @@ class ClientHelper(conf: Configuration) extends Logging {
         } else {
           Some(ApplicationStatus.unsuccessfulFinish(new SamzaException(diagnostics)))
         }
-      case (YarnApplicationState.NEW, _) | (YarnApplicationState.SUBMITTED, _) => Some(New)
-      case _ => Some(Running)
+      case (YarnApplicationState.RUNNING, _) =>
+        if (allContainersRunning(applicationReport)) {
+          Some(Running)
+        } else {
+          Some(New)
+        }
+      case _ =>
+        Some(New)
+    }
+  }
+
+  def allContainersRunning(applicationReport: ApplicationReport): Boolean = {
+    val amClient: ApplicationMasterRestClient = createAmClient(applicationReport)
+
+    debug("Created client: " + amClient.toString)
+
+    try {
+      val metrics = amClient.getMetrics
+      debug("Got metrics: " + metrics.toString)
+
+      val neededContainers = Integer.parseInt(
+        metrics.get(classOf[SamzaAppMasterMetrics].getCanonicalName)
+          .get("needed-containers")
+          .toString)
+
+      info("Needed containers: " + neededContainers)
+      if (neededContainers == 0) {
+        true
+      } else {
+        false
+      }
+    } finally {
+      amClient.close()
     }
   }
 

http://git-wip-us.apache.org/repos/asf/samza/blob/7a2b4650/samza-yarn/src/main/scala/org/apache/samza/job/yarn/YarnJob.scala
----------------------------------------------------------------------
diff --git a/samza-yarn/src/main/scala/org/apache/samza/job/yarn/YarnJob.scala b/samza-yarn/src/main/scala/org/apache/samza/job/yarn/YarnJob.scala
index 3e9d0fe..d335448 100644
--- a/samza-yarn/src/main/scala/org/apache/samza/job/yarn/YarnJob.scala
+++ b/samza-yarn/src/main/scala/org/apache/samza/job/yarn/YarnJob.scala
@@ -21,9 +21,10 @@ package org.apache.samza.job.yarn
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.yarn.api.ApplicationConstants
 import org.apache.hadoop.yarn.api.records.ApplicationId
+import org.apache.samza.SamzaException
 import org.apache.samza.config.JobConfig.Config2Job
 import org.apache.samza.config.{Config, JobConfig, ShellCommandConfig, YarnConfig}
-import org.apache.samza.job.ApplicationStatus.{Running, SuccessfulFinish, UnsuccessfulFinish}
+import org.apache.samza.job.ApplicationStatus.{SuccessfulFinish, UnsuccessfulFinish}
 import org.apache.samza.job.{ApplicationStatus, StreamJob}
 import org.apache.samza.serializers.model.SamzaObjectMapper
 import org.apache.samza.util.{CoordinatorStreamUtil, Util}
@@ -115,7 +116,7 @@ class YarnJob(config: Config, hadoopConfig: Configuration) extends StreamJob {
       Thread.sleep(1000)
     }
 
-    Running
+    getStatus
   }
 
   def waitForStatus(status: ApplicationStatus, timeoutMs: Long): ApplicationStatus = {
@@ -130,14 +131,15 @@ class YarnJob(config: Config, hadoopConfig: Configuration) extends StreamJob {
       Thread.sleep(1000)
     }
 
-    Running
+    getStatus
   }
 
   def getStatus: ApplicationStatus = {
     getAppId match {
       case Some(appId) =>
         logger.info("Getting status for applicationId %s" format appId)
-        client.status(appId).getOrElse(null)
+        client.status(appId).getOrElse(
+          throw new SamzaException("No status was determined for applicationId %s" format appId))
       case None =>
         logger.info("Unable to report status because no applicationId could be found.")
         ApplicationStatus.SuccessfulFinish

http://git-wip-us.apache.org/repos/asf/samza/blob/7a2b4650/samza-yarn/src/main/scala/org/apache/samza/webapp/ApplicationMasterRestServlet.scala
----------------------------------------------------------------------
diff --git a/samza-yarn/src/main/scala/org/apache/samza/webapp/ApplicationMasterRestServlet.scala b/samza-yarn/src/main/scala/org/apache/samza/webapp/ApplicationMasterRestServlet.scala
index 122a1df..5c10987 100644
--- a/samza-yarn/src/main/scala/org/apache/samza/webapp/ApplicationMasterRestServlet.scala
+++ b/samza-yarn/src/main/scala/org/apache/samza/webapp/ApplicationMasterRestServlet.scala
@@ -19,44 +19,41 @@
 
 package org.apache.samza.webapp
 
+import java.{lang, util}
+
 import org.apache.samza.clustermanager.SamzaApplicationState
 import org.scalatra._
 import scalate.ScalateSupport
 import org.apache.samza.config.Config
-import org.apache.samza.job.yarn.{YarnAppState, ClientHelper}
+import org.apache.samza.job.yarn.{ClientHelper, YarnAppState}
 import org.apache.samza.metrics._
+
 import scala.collection.JavaConverters._
 import org.apache.hadoop.yarn.conf.YarnConfiguration
 import java.util.HashMap
-import org.apache.samza.serializers.model.SamzaObjectMapper
 
-class ApplicationMasterRestServlet(samzaConfig: Config, samzaAppState: SamzaApplicationState, state: YarnAppState, registry: ReadableMetricsRegistry) extends ScalatraServlet with ScalateSupport {
-  val yarnConfig = new YarnConfiguration
-  val client = new ClientHelper(yarnConfig)
-  val jsonMapper = SamzaObjectMapper.getObjectMapper
-
-  before() {
-    contentType = "application/json"
-  }
+import org.apache.samza.serializers.model.SamzaObjectMapper
+import org.codehaus.jackson.map.ObjectMapper
 
-  get("/metrics") {
-    val metricMap = new HashMap[String, java.util.Map[String, Object]]
+object ApplicationMasterRestServlet {
+  def getMetrics(jsonMapper: ObjectMapper, metricsRegistry: ReadableMetricsRegistry) = {
+    val metricMap = new HashMap[String, util.Map[String, Object]]
 
     // build metric map
-    registry.getGroups.asScala.foreach(group => {
+    metricsRegistry.getGroups.asScala.foreach(group => {
       val groupMap = new HashMap[String, Object]
 
-      registry.getGroup(group).asScala.foreach {
+      metricsRegistry.getGroup(group).asScala.foreach {
         case (name, metric) =>
           metric.visit(new MetricsVisitor() {
             def counter(counter: Counter) =
-              groupMap.put(counter.getName, counter.getCount: java.lang.Long)
+              groupMap.put(counter.getName, counter.getCount: lang.Long)
 
             def gauge[T](gauge: Gauge[T]) =
-              groupMap.put(gauge.getName, gauge.getValue.asInstanceOf[java.lang.Object])
+              groupMap.put(gauge.getName, gauge.getValue.asInstanceOf[Object])
 
             def timer(timer: Timer) =
-              groupMap.put(timer.getName, timer.getSnapshot().getAverage: java.lang.Double)
+              groupMap.put(timer.getName, timer.getSnapshot().getAverage: lang.Double)
           })
       }
 
@@ -66,18 +63,18 @@ class ApplicationMasterRestServlet(samzaConfig: Config, samzaAppState: SamzaAppl
     jsonMapper.writeValueAsString(metricMap)
   }
 
-  get("/task-context") {
+  def getTaskContext(jsonMapper: ObjectMapper, state: YarnAppState) = {
     // sick of fighting with scala.. just using java map for now
     val contextMap = new HashMap[String, Object]
 
-    contextMap.put("task-id", state.taskId: java.lang.Integer)
+    contextMap.put("task-id", state.taskId: Integer)
     contextMap.put("name", state.amContainerId.toString)
 
     jsonMapper.writeValueAsString(contextMap)
   }
 
-  get("/am") {
-    val containers = new HashMap[String, HashMap[String, Object]]
+  def getAmState(jsonMapper: ObjectMapper, samzaAppState: SamzaApplicationState, state: YarnAppState) = {
+    val containers = new HashMap[String, util.HashMap[String, Object]]
 
     state.runningYarnContainers.asScala.foreach {
       case (containerId, container) =>
@@ -101,7 +98,42 @@ class ApplicationMasterRestServlet(samzaConfig: Config, samzaAppState: SamzaAppl
     jsonMapper.writeValueAsString(new HashMap[String, Object](status.asJava))
   }
 
-  get("/config") {
+  def getConfig(jsonMapper: ObjectMapper, samzaConfig: Config) = {
     jsonMapper.writeValueAsString(new HashMap[String, Object](samzaConfig.sanitize))
   }
 }
+
+/**
+  * Defines the Scalatra routes for the servlet.
+  */
+class ApplicationMasterRestServlet(samzaConfig: Config, samzaAppState: SamzaApplicationState, state: YarnAppState, registry: ReadableMetricsRegistry) extends ScalatraServlet with ScalateSupport {
+  val yarnConfig = new YarnConfiguration
+  val client = new ClientHelper(yarnConfig)
+  val jsonMapper = SamzaObjectMapper.getObjectMapper
+
+  before() {
+    contentType = "application/json"
+  }
+
+  get("/metrics") {
+    ApplicationMasterRestServlet.getMetrics(jsonMapper, registry)
+  }
+
+
+
+  get("/task-context") {
+    ApplicationMasterRestServlet.getTaskContext(jsonMapper, state)
+  }
+
+
+
+  get("/am") {
+    ApplicationMasterRestServlet.getAmState(jsonMapper, samzaAppState, state)
+  }
+
+
+
+  get("/config") {
+    ApplicationMasterRestServlet.getConfig(jsonMapper, samzaConfig)
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/7a2b4650/samza-yarn/src/test/java/org/apache/samza/webapp/TestApplicationMasterRestClient.java
----------------------------------------------------------------------
diff --git a/samza-yarn/src/test/java/org/apache/samza/webapp/TestApplicationMasterRestClient.java b/samza-yarn/src/test/java/org/apache/samza/webapp/TestApplicationMasterRestClient.java
new file mode 100644
index 0000000..dbe534f
--- /dev/null
+++ b/samza-yarn/src/test/java/org/apache/samza/webapp/TestApplicationMasterRestClient.java
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.webapp;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import java.io.IOException;
+import java.io.StringReader;
+import java.net.MalformedURLException;
+import java.net.URL;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import org.apache.commons.io.input.ReaderInputStream;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.Container;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.util.ConverterUtils;
+import org.apache.http.HttpEntity;
+import org.apache.http.HttpHost;
+import org.apache.http.HttpStatus;
+import org.apache.http.StatusLine;
+import org.apache.http.client.methods.CloseableHttpResponse;
+import org.apache.http.client.methods.HttpGet;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.samza.Partition;
+import org.apache.samza.SamzaException;
+import org.apache.samza.clustermanager.SamzaApplicationState;
+import org.apache.samza.clustermanager.SamzaResource;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.container.grouper.task.GroupByContainerCount;
+import org.apache.samza.coordinator.JobModelManager;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.JobModel;
+import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.job.yarn.SamzaAppMasterMetrics;
+import org.apache.samza.job.yarn.YarnAppState;
+import org.apache.samza.job.yarn.YarnContainer;
+import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.serializers.model.SamzaObjectMapper;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamPartition;
+import org.codehaus.jackson.map.ObjectMapper;
+import org.codehaus.jackson.type.TypeReference;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+
+public class TestApplicationMasterRestClient {
+  private static final String AM_HOST_NAME = "dummyHost";
+  private static final int AM_RPC_PORT = 1337;
+  private static final int AM_HTTP_PORT = 7001;
+  private static final String YARN_CONTAINER_ID_1 = "container_e38_1510966221296_0007_01_000001";
+  private static final String YARN_CONTAINER_ID_2 = "container_e38_1510966221296_0007_01_000002";
+  private static final String YARN_CONTAINER_ID_3 = "container_e38_1510966221296_0007_01_000003";
+  private static final String APP_ATTEMPT_ID = "appattempt_1510966221296_0007_000001";
+
+  private final ObjectMapper jsonMapper = SamzaObjectMapper.getObjectMapper();
+
+  private CloseableHttpClient mockClient;
+
+  @Before
+  public void setup() {
+    mockClient = mock(CloseableHttpClient.class);
+  }
+
+  @Rule
+  public ExpectedException expectedException = ExpectedException.none(); // Enables us to verify the exception message
+
+  @Test
+  public void testGetMetricsSuccess() throws IOException {
+    SamzaApplicationState samzaAppState = createSamzaApplicationState();
+
+    MetricsRegistryMap registry = new MetricsRegistryMap();
+    assignMetricValues(samzaAppState, registry);
+
+    String response = ApplicationMasterRestServlet.getMetrics(jsonMapper, registry);
+    setupMockClientResponse(HttpStatus.SC_OK, "Success", response);
+
+    ApplicationMasterRestClient client = new ApplicationMasterRestClient(mockClient, AM_HOST_NAME, AM_RPC_PORT);
+    Map<String, Map<String, Object>> metricsResult = client.getMetrics();
+
+    String group = SamzaAppMasterMetrics.class.getCanonicalName();
+    assertEquals(1, metricsResult.size());
+    assertTrue(metricsResult.containsKey(group));
+
+    Map<String, Object> amMetricsGroup = metricsResult.get(group);
+    assertEquals(7, amMetricsGroup.size());
+    assertEquals(samzaAppState.runningContainers.size(),  amMetricsGroup.get("running-containers"));
+    assertEquals(samzaAppState.neededContainers.get(),    amMetricsGroup.get("needed-containers"));
+    assertEquals(samzaAppState.completedContainers.get(), amMetricsGroup.get("completed-containers"));
+    assertEquals(samzaAppState.failedContainers.get(),    amMetricsGroup.get("failed-containers"));
+    assertEquals(samzaAppState.releasedContainers.get(),  amMetricsGroup.get("released-containers"));
+    assertEquals(samzaAppState.containerCount.get(),      amMetricsGroup.get("container-count"));
+    assertEquals(samzaAppState.jobHealthy.get() ? 1 : 0,  amMetricsGroup.get("job-healthy"));
+  }
+
+  @Test
+  public void testGetMetricsError() throws IOException {
+    setupErrorTest("metrics");
+
+    ApplicationMasterRestClient client = new ApplicationMasterRestClient(mockClient, AM_HOST_NAME, AM_RPC_PORT);
+    client.getMetrics();
+  }
+
+  @Test
+  public void testGetTaskContextSuccess() throws IOException {
+    ContainerId containerId = ConverterUtils.toContainerId(YARN_CONTAINER_ID_1);
+    YarnAppState yarnAppState = createYarnAppState(containerId);
+
+    String response = ApplicationMasterRestServlet.getTaskContext(jsonMapper, yarnAppState);
+    setupMockClientResponse(HttpStatus.SC_OK, "Success", response);
+
+    ApplicationMasterRestClient client = new ApplicationMasterRestClient(mockClient, AM_HOST_NAME, AM_RPC_PORT);
+    Map<String, Object> taskContextResult = client.getTaskContext();
+
+    assertEquals(2, taskContextResult.size());
+    assertEquals(2, taskContextResult.get("task-id"));
+    assertEquals(containerId.toString(), taskContextResult.get("name"));
+  }
+
+  @Test
+  public void testTaskContextError() throws IOException {
+    setupErrorTest("task context");
+
+    ApplicationMasterRestClient client = new ApplicationMasterRestClient(mockClient, AM_HOST_NAME, AM_RPC_PORT);
+    client.getTaskContext();
+  }
+
+  @Test
+  public void testGetAmStateSuccess() throws IOException {
+    SamzaApplicationState samzaAppState = createSamzaApplicationState();
+
+    ApplicationAttemptId attemptId = ConverterUtils.toApplicationAttemptId(APP_ATTEMPT_ID);
+    ContainerId containerId = ConverterUtils.toContainerId(YARN_CONTAINER_ID_1);
+    YarnAppState yarnAppState = createYarnAppState(containerId);
+
+    String response = ApplicationMasterRestServlet.getAmState(jsonMapper, samzaAppState, yarnAppState);
+    setupMockClientResponse(HttpStatus.SC_OK, "Success", response);
+
+    ApplicationMasterRestClient client = new ApplicationMasterRestClient(mockClient, AM_HOST_NAME, AM_RPC_PORT);
+    Map<String, Object> amStateResult = client.getAmState();
+
+    assertEquals(4, amStateResult.size());
+    assertEquals(String.format("%s:%s", yarnAppState.nodeHost, yarnAppState.rpcUrl.getPort()), amStateResult.get("host"));
+    assertEquals(containerId.toString(), amStateResult.get("container-id"));
+    // Can only validate the keys because up-time changes everytime it's requested
+    assertEquals(buildExpectedContainerResponse(yarnAppState.runningYarnContainers, samzaAppState).keySet(),
+        ((Map<String, Object>) amStateResult.get("containers")).keySet());
+    assertEquals(attemptId.toString(), amStateResult.get("app-attempt-id"));
+  }
+
+  @Test
+  public void testGetAmStateError() throws IOException {
+    setupErrorTest("AM state");
+
+    ApplicationMasterRestClient client = new ApplicationMasterRestClient(mockClient, AM_HOST_NAME, AM_RPC_PORT);
+    client.getAmState();
+  }
+
+  @Test
+  public void testGetConfigSuccess() throws IOException {
+    SamzaApplicationState samzaAppState = createSamzaApplicationState();
+
+    Map<String, String> configMap = ImmutableMap.of("key1", "value1", "key2", "value2");
+    Config config = new MapConfig(configMap);
+
+    String response = ApplicationMasterRestServlet.getConfig(jsonMapper, config);
+    setupMockClientResponse(HttpStatus.SC_OK, "Success", response);
+
+    ApplicationMasterRestClient client = new ApplicationMasterRestClient(mockClient, AM_HOST_NAME, AM_RPC_PORT);
+    Map<String, Object> configResult = client.getConfig();
+
+    assertEquals(configMap, configResult);
+  }
+
+  @Test
+  public void testGetConfigError() throws IOException {
+    setupErrorTest("config");
+
+    ApplicationMasterRestClient client = new ApplicationMasterRestClient(mockClient, AM_HOST_NAME, AM_RPC_PORT);
+    client.getConfig();
+  }
+
+  @Test
+  public void testCloseMethodClosesHttpClient() throws IOException {
+    ApplicationMasterRestClient client = new ApplicationMasterRestClient(mockClient, AM_HOST_NAME, AM_RPC_PORT);
+    client.close();
+
+    verify(mockClient).close();
+  }
+
+  private void setupMockClientResponse(int statusCode, String statusReason, String responseBody) throws IOException {
+    StatusLine statusLine = mock(StatusLine.class);
+    when(statusLine.getStatusCode()).thenReturn(statusCode);
+    when(statusLine.getReasonPhrase()).thenReturn(statusReason);
+
+    HttpEntity entity = mock(HttpEntity.class);
+    when(entity.getContent()).thenReturn(new ReaderInputStream(new StringReader(responseBody)));
+
+    CloseableHttpResponse response = mock(CloseableHttpResponse.class);
+    when(response.getStatusLine()).thenReturn(statusLine);
+    when(response.getEntity()).thenReturn(entity);
+
+    when(mockClient.execute(any(HttpHost.class), any(HttpGet.class))).thenReturn(response);
+  }
+
+  /**
+   * Builds a SamzaApplicationState backed by a mocked JobModelManager whose job model exposes the containers
+   * from {@link #generateContainers()}, and registers two SamzaResources as running containers.
+   */
+  private SamzaApplicationState createSamzaApplicationState() {
+    HashMap<String, ContainerModel> containerModels = generateContainers();
+
+    // Stubs are set up before constructing the state object, which receives the manager.
+    JobModel jobModel = mock(JobModel.class);
+    when(jobModel.getContainers()).thenReturn(containerModels);
+    JobModelManager jobModelManager = mock(JobModelManager.class);
+    when(jobModelManager.jobModel()).thenReturn(jobModel);
+
+    SamzaApplicationState appState = new SamzaApplicationState(jobModelManager);
+
+    SamzaResource firstResource = new SamzaResource(1, 2, "dummyNodeHost1", "dummyResourceId1");
+    SamzaResource secondResource = new SamzaResource(2, 4, "dummyNodeHost2", "dummyResourceId2");
+    appState.runningContainers.put(YARN_CONTAINER_ID_3, firstResource);
+    appState.runningContainers.put(YARN_CONTAINER_ID_2, secondResource);
+    return appState;
+  }
+
+  /**
+   * Builds a YarnAppState for the given AM container, points its RPC URL at the test AM host/port, and
+   * registers two running YARN containers under Samza container IDs "0" and "1".
+   */
+  private YarnAppState createYarnAppState(ContainerId containerId) throws MalformedURLException {
+    YarnAppState appState = new YarnAppState(2, containerId, AM_HOST_NAME, AM_RPC_PORT, AM_HTTP_PORT);
+    appState.rpcUrl = new URL(new HttpHost(AM_HOST_NAME, AM_RPC_PORT).toURI());
+    appState.runningYarnContainers.put("0",
+        newDummyYarnContainer(YARN_CONTAINER_ID_2, "dummyNodeHost1", "dummyNodeHttpHost1"));
+    appState.runningYarnContainers.put("1",
+        newDummyYarnContainer(YARN_CONTAINER_ID_3, "dummyNodeHost2", "dummyNodeHttpHost2"));
+    return appState;
+  }
+
+  /** Wraps a YARN Container with the given ID, node and HTTP address; the remaining newInstance arguments are null. */
+  private static YarnContainer newDummyYarnContainer(String yarnContainerId, String nodeHost, String nodeHttpAddress) {
+    return new YarnContainer(Container.newInstance(
+        ConverterUtils.toContainerId(yarnContainerId),
+        ConverterUtils.toNodeIdWithDefaultPort(nodeHost),
+        nodeHttpAddress,
+        null,
+        null,
+        null
+    ));
+  }
+
+  /**
+   * Groups two tasks (task1 on partition 0 and task2 on partition 1 of system1.stream1) into two containers
+   * with GroupByContainerCount and indexes the resulting ContainerModels by processor ID.
+   */
+  private HashMap<String, ContainerModel> generateContainers() {
+    SystemStream inputStream = new SystemStream("system1", "stream1");
+    Set<TaskModel> tasks = ImmutableSet.of(
+        new TaskModel(new TaskName("task1"),
+            ImmutableSet.of(new SystemStreamPartition(inputStream, new Partition(0))),
+            new Partition(0)),
+        new TaskModel(new TaskName("task2"),
+            ImmutableSet.of(new SystemStreamPartition(inputStream, new Partition(1))),
+            new Partition(1)));
+
+    HashMap<String, ContainerModel> containersByProcessorId = new HashMap<>();
+    new GroupByContainerCount(2).group(tasks)
+        .forEach(model -> containersByProcessorId.put(model.getProcessorId(), model));
+    return containersByProcessorId;
+  }
+
+  /**
+   * Builds the container map the AM REST endpoint is expected to return, keyed by YARN container ID.
+   * The map is round-tripped through {@code jsonMapper} so that its value types match what a JSON
+   * round trip produces, which makes it directly comparable with the client's result.
+   *
+   * @param runningYarnContainers running containers keyed by Samza container (processor) ID
+   * @param samzaAppState application state whose job model supplies each container's task models
+   * @return the expected response, keyed by YARN container ID
+   * @throws IOException if JSON serialization or deserialization fails
+   */
+  private Map<String, Map<String, Object>> buildExpectedContainerResponse(Map<String, YarnContainer> runningYarnContainers,
+      SamzaApplicationState samzaAppState) throws IOException {
+    Map<String, Map<String, Object>> containers = new HashMap<>();
+
+    runningYarnContainers.forEach((containerId, container) -> {
+      String yarnContainerId = container.id().toString();
+      // Parameterized (was a raw HashMap, which produced an unchecked-conversion warning).
+      Map<String, Object> containerMap = new HashMap<>();
+      Map<TaskName, TaskModel> taskModels = samzaAppState.jobModelManager.jobModel().getContainers().get(containerId).getTasks();
+      containerMap.put("yarn-address", container.nodeHttpAddress());
+      containerMap.put("start-time", String.valueOf(container.startTime()));
+      containerMap.put("up-time", String.valueOf(container.upTime()));
+      containerMap.put("task-models", taskModels);
+      containerMap.put("container-id", containerId);
+      containers.put(yarnContainerId, containerMap);
+    });
+
+    return jsonMapper.readValue(jsonMapper.writeValueAsString(containers), new TypeReference<Map<String, Map<String, Object>>>() {});
+  }
+
+  /**
+   * Starts SamzaAppMasterMetrics against {@code registry}, then fills the application state with distinct
+   * values (1 running, 2 needed, 3 completed, 4 failed, 5 released, 6 total containers, healthy job).
+   * Presumably the calling test reads these back through the registry.
+   */
+  private void assignMetricValues(SamzaApplicationState samzaAppState, MetricsRegistryMap registry) {
+    SamzaAppMasterMetrics appMasterMetrics = new SamzaAppMasterMetrics(new MapConfig(), samzaAppState, registry);
+    appMasterMetrics.start();
+
+    // Exactly one running container.
+    SamzaResource runningResource = new SamzaResource(1, 2, AM_HOST_NAME, "dummyResourceId");
+    samzaAppState.runningContainers.put("dummyContainer", runningResource);
+
+    samzaAppState.neededContainers.set(2);
+    samzaAppState.completedContainers.set(3);
+    samzaAppState.failedContainers.set(4);
+    samzaAppState.releasedContainers.set(5);
+    samzaAppState.containerCount.set(6);
+    samzaAppState.jobHealthy.set(true);
+  }
+
+  /**
+   * Prepares an error-path test. It expects a SamzaException whose message names the fetched entity, the AM
+   * host URI and the status reason, and stubs the HTTP client to answer with a 500 and an empty body.
+   */
+  private void setupErrorTest(String entityToFetch) throws IOException {
+    String statusReason = "Dummy status reason";
+    String amHostUri = new HttpHost(AM_HOST_NAME, AM_RPC_PORT).toURI();
+    String expectedMessage = String.format(
+        "Error retrieving %s from host %s. Response: %s", entityToFetch, amHostUri, statusReason);
+
+    expectedException.expect(SamzaException.class);
+    expectedException.expectMessage(expectedMessage);
+
+    setupMockClientResponse(HttpStatus.SC_INTERNAL_SERVER_ERROR, statusReason, "");
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/7a2b4650/samza-yarn/src/test/scala/org/apache/samza/job/yarn/TestClientHelper.scala
----------------------------------------------------------------------
diff --git a/samza-yarn/src/test/scala/org/apache/samza/job/yarn/TestClientHelper.scala b/samza-yarn/src/test/scala/org/apache/samza/job/yarn/TestClientHelper.scala
index ee947ae..6df6589 100644
--- a/samza-yarn/src/test/scala/org/apache/samza/job/yarn/TestClientHelper.scala
+++ b/samza-yarn/src/test/scala/org/apache/samza/job/yarn/TestClientHelper.scala
@@ -20,19 +20,16 @@ package org.apache.samza.job.yarn
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.permission.FsPermission
-import org.apache.hadoop.fs.{FileStatus, Path, FileSystem}
-import org.apache.hadoop.yarn.api.records.ApplicationId
-import org.apache.hadoop.yarn.api.records.ApplicationReport
-import org.apache.hadoop.yarn.api.records.FinalApplicationStatus
-import org.apache.hadoop.yarn.api.records.YarnApplicationState
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.yarn.api.records.{ApplicationId, ApplicationReport, FinalApplicationStatus, YarnApplicationState}
 import org.apache.hadoop.yarn.client.api.YarnClient
 import org.apache.samza.SamzaException
-import org.apache.samza.config.{MapConfig, JobConfig, YarnConfig}
+import org.apache.samza.config.{JobConfig, MapConfig, YarnConfig}
 import org.apache.samza.job.ApplicationStatus
-import org.junit.Assert.assertEquals
-import org.junit.Assert.assertNotNull
-import org.mockito.Mockito._
+import org.apache.samza.webapp.ApplicationMasterRestClient
+import org.junit.Assert.{assertEquals, assertNotNull}
 import org.mockito.Matchers.any
+import org.mockito.Mockito._
 import org.scalatest.FunSuite
 import org.scalatest.mockito.MockitoSugar
 
@@ -40,11 +37,15 @@ import org.scalatest.mockito.MockitoSugar
 class TestClientHelper extends FunSuite {
   import MockitoSugar._
+  // Mocked Hadoop configuration handed to the ClientHelper under test.
   val hadoopConfig = mock[Configuration]
+  // AM REST client stub returned from clientHelper.createAmClient; tests stub getMetrics on it.
+  val mockAmClient = mock[ApplicationMasterRestClient]
 
+  // ClientHelper under test, with createYarnClient and createAmClient overridden to return mocks.
   val clientHelper = new ClientHelper(hadoopConfig) {
     override def createYarnClient() = {
       mock[YarnClient]
     }
+    // Every application report resolves to the shared mockAmClient stub.
+    override def createAmClient(applicationReport: ApplicationReport) = {
+      mockAmClient
+    }
   }
 
   test("test validateJobConfig") {
@@ -120,5 +121,22 @@ class TestClientHelper extends FunSuite {
     when(appReport.getFinalApplicationStatus).thenReturn(FinalApplicationStatus.SUCCEEDED)
     appStatus = clientHelper.toAppStatus(appReport).get
     assertEquals(appStatus, ApplicationStatus.SuccessfulFinish)
+
+    val appMasterMetrics =  new java.util.HashMap[String, Object]()
+    val metrics = new java.util.HashMap[String, java.util.Map[String, Object]]()
+    metrics.put(classOf[SamzaAppMasterMetrics].getCanonicalName(), appMasterMetrics)
+    appMasterMetrics.put("needed-containers", "1")
+    when(mockAmClient.getMetrics).thenReturn(metrics)
+
+    when(appReport.getYarnApplicationState).thenReturn(YarnApplicationState.RUNNING)
+    appStatus = clientHelper.toAppStatus(appReport).get
+    assertEquals(appStatus, ApplicationStatus.New) // Should not be RUNNING if there are still needed containers
+
+    appMasterMetrics.put("needed-containers", "0")
+    when(mockAmClient.getMetrics).thenReturn(metrics)
+
+    when(appReport.getYarnApplicationState).thenReturn(YarnApplicationState.RUNNING)
+    appStatus = clientHelper.toAppStatus(appReport).get
+    assertEquals(appStatus, ApplicationStatus.Running) // Should not be RUNNING if there are still needed containers
   }
 }