You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@samza.apache.org by bo...@apache.org on 2018/09/26 00:23:01 UTC

[02/29] samza git commit: SAMZA-1874: Refactor SamzaContainer and TaskInstance unit tests to make shared context changes easier

SAMZA-1874: Refactor SamzaContainer and TaskInstance unit tests to make shared context changes easier

This replaces https://github.com/apache/samza/pull/638 because I accidentally messed up that branch.
The changes made since prateekm's last review are in https://github.com/apache/samza/pull/646/commits/5d552996ac50d2a0b1dd5034a624d9417e74dc57

Author: Cameron Lee <ca...@linkedin.com>

Reviewers: Prateek Maheshwari <pm...@apache.org>

Closes #646 from cameronlee314/refactor_unit_tests_for_shared_context_new


Project: http://git-wip-us.apache.org/repos/asf/samza/repo
Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/19c6f4f6
Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/19c6f4f6
Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/19c6f4f6

Branch: refs/heads/NewKafkaSystemConsumer
Commit: 19c6f4f6131ce86aa492073720f7168197a1f103
Parents: b7219e9
Author: Cameron Lee <ca...@linkedin.com>
Authored: Wed Sep 19 10:21:57 2018 -0700
Committer: Prateek Maheshwari <pm...@apache.org>
Committed: Wed Sep 19 10:21:57 2018 -0700

----------------------------------------------------------------------
 .../apache/samza/job/model/TestJobModel.java    |  50 ++
 .../samza/container/TestSamzaContainer.scala    | 729 ++++---------------
 .../samza/container/TestTaskInstance.scala      | 526 ++++---------
 .../TestTaskInstanceExceptionHandler.scala      | 144 ++++
 .../samza/system/chooser/MockSystemAdmin.scala  |  30 +
 .../chooser/TestBootstrappingChooser.scala      |   3 +-
 .../system/chooser/TestDefaultChooser.scala     |   1 -
 7 files changed, 528 insertions(+), 955 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/samza/blob/19c6f4f6/samza-core/src/test/java/org/apache/samza/job/model/TestJobModel.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/job/model/TestJobModel.java b/samza-core/src/test/java/org/apache/samza/job/model/TestJobModel.java
new file mode 100644
index 0000000..6c7c282
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/job/model/TestJobModel.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.job.model;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import java.util.Map;
+import org.apache.samza.Partition;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.container.TaskName;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+/** Unit tests for {@link JobModel}. */
+public class TestJobModel {
+  @Test
+  public void testMaxChangeLogStreamPartitions() {
+    Config config = new MapConfig(ImmutableMap.of("a", "b"));
+    Map<TaskName, TaskModel> tasksForContainer1 = ImmutableMap.of(
+        new TaskName("t1"), new TaskModel(new TaskName("t1"), ImmutableSet.of(), new Partition(0)),
+        new TaskName("t2"), new TaskModel(new TaskName("t2"), ImmutableSet.of(), new Partition(1)));
+    Map<TaskName, TaskModel> tasksForContainer2 = ImmutableMap.of(
+        new TaskName("t3"), new TaskModel(new TaskName("t3"), ImmutableSet.of(), new Partition(2)),
+        new TaskName("t4"), new TaskModel(new TaskName("t4"), ImmutableSet.of(), new Partition(3)),
+        new TaskName("t5"), new TaskModel(new TaskName("t5"), ImmutableSet.of(), new Partition(4)));
+    ContainerModel containerModel1 = new ContainerModel("0", 0, tasksForContainer1);
+    ContainerModel containerModel2 = new ContainerModel("1", 1, tasksForContainer2);
+    Map<String, ContainerModel> containers = ImmutableMap.of("0", containerModel1, "1", containerModel2);
+    JobModel jobModel = new JobModel(config, containers);
+    assertEquals(5, jobModel.maxChangeLogStreamPartitions); // expected first; tasks span changelog partitions 0-4
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/19c6f4f6/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
index ff57047..30ca8c1 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
@@ -20,27 +20,22 @@
 package org.apache.samza.container
 
 import java.util
-import java.util.concurrent.TimeUnit
 import java.util.concurrent.atomic.AtomicReference
 
-import org.apache.samza.checkpoint.{Checkpoint, CheckpointManager}
-import org.apache.samza.config.{Config, MapConfig}
+import org.apache.samza.config.MapConfig
 import org.apache.samza.coordinator.JobModelManager
 import org.apache.samza.coordinator.server.{HttpServer, JobServlet}
 import org.apache.samza.job.model.{ContainerModel, JobModel, TaskModel}
-import org.apache.samza.metrics.MetricsRegistryMap
-import org.apache.samza.serializers.SerdeManager
-import org.apache.samza.storage.TaskStorageManager
+import org.apache.samza.metrics.{Gauge, Timer}
 import org.apache.samza.system._
-import org.apache.samza.system.chooser.RoundRobinChooser
-import org.apache.samza.task._
-import org.apache.samza.util.SinglePartitionWithoutOffsetsSystemAdmin
 import org.apache.samza.{Partition, SamzaContainerStatus}
 import org.junit.Assert._
-import org.junit.Test
-import org.mockito.Mockito.when
+import org.junit.{Before, Test}
+import org.mockito.Matchers.{any, notNull}
+import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
+import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations}
 import org.scalatest.junit.AssertionsForJUnit
 import org.scalatest.mockito.MockitoSugar
 
@@ -48,8 +43,137 @@ import scala.collection.JavaConversions._
 import scala.collection.JavaConverters._
 
 class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
+  private val TASK_NAME = new TaskName("taskName")
+  // Each @Mock field below starts as null and is replaced by a fresh mock via MockitoAnnotations.initMocks in setup().
+  @Mock
+  private var containerContext: SamzaContainerContext = null
+  @Mock
+  private var taskInstance: TaskInstance = null
+  @Mock
+  private var runLoop: Runnable = null
+  @Mock
+  private var systemAdmins: SystemAdmins = null
+  @Mock
+  private var consumerMultiplexer: SystemConsumers = null
+  @Mock
+  private var producerMultiplexer: SystemProducers = null
+  @Mock
+  private var metrics: SamzaContainerMetrics = null
+  @Mock
+  private var samzaContainerListener: SamzaContainerListener = null
+
+  private var samzaContainer: SamzaContainer = null
+  /** Rebuilds the container under test around the fresh mocks and attaches the listener mock before every test. */
+  @Before
+  def setup(): Unit = {
+    MockitoAnnotations.initMocks(this)
+    this.samzaContainer = new SamzaContainer(
+      this.containerContext,
+      Map(TASK_NAME -> this.taskInstance),
+      this.runLoop,
+      this.systemAdmins,
+      this.consumerMultiplexer,
+      this.producerMultiplexer,
+      this.metrics)
+    this.samzaContainer.setContainerListener(this.samzaContainerListener)
+    when(this.metrics.containerStartupTime).thenReturn(mock[Timer])
+  }
+
+  @Test
+  def testExceptionInTaskInitShutsDownTask(): Unit = {
+    when(this.taskInstance.initTask).thenThrow(new RuntimeException("Trigger a shutdown, please."))
+
+    this.samzaContainer.run
+    // A failing initTask still shuts the task down, marks the container FAILED, and never enters the run loop.
+    verify(this.taskInstance).shutdownTask
+    assertEquals(SamzaContainerStatus.FAILED, this.samzaContainer.getStatus())
+    verify(this.samzaContainerListener).beforeStart()
+    verify(this.samzaContainerListener, never()).afterStart()
+    verify(this.samzaContainerListener, never()).afterStop()
+    verify(this.samzaContainerListener).afterFailure(notNull(classOf[Exception]))
+    verifyZeroInteractions(this.runLoop)
+  }
+
+  @Test
+  def testErrorInTaskInitShutsDownTask(): Unit = {
+    when(this.taskInstance.initTask).thenThrow(new NoSuchMethodError("Trigger a shutdown, please."))
+    // NoSuchMethodError is an Error, not an Exception: init failures of any Throwable kind must be handled.
+    this.samzaContainer.run
+    // afterFailure takes a Throwable and notNull(Class) does not type-check, so match Throwable, not Exception.
+    verify(this.taskInstance).shutdownTask
+    assertEquals(SamzaContainerStatus.FAILED, this.samzaContainer.getStatus())
+    verify(this.samzaContainerListener).beforeStart()
+    verify(this.samzaContainerListener, never()).afterStart()
+    verify(this.samzaContainerListener, never()).afterStop()
+    verify(this.samzaContainerListener).afterFailure(notNull(classOf[Throwable]))
+    verifyZeroInteractions(this.runLoop)
+  }
+
+  @Test
+  def testExceptionInTaskProcessRunLoop(): Unit = {
+    when(this.runLoop.run()).thenThrow(new RuntimeException("Trigger a shutdown, please."))
+
+    this.samzaContainer.run
+    // The run loop was entered (afterStart fired), but its failure still shuts the task down and marks FAILED.
+    verify(this.taskInstance).shutdownTask
+    assertEquals(SamzaContainerStatus.FAILED, this.samzaContainer.getStatus())
+    verify(this.samzaContainerListener).beforeStart()
+    verify(this.samzaContainerListener).afterStart()
+    verify(this.samzaContainerListener, never()).afterStop()
+    verify(this.samzaContainerListener).afterFailure(notNull(classOf[Exception]))
+    verify(this.runLoop).run()
+  }
+
+  @Test
+  def testCleanRun(): Unit = {
+    doNothing().when(this.runLoop).run() // run loop completes successfully
+
+    this.samzaContainer.run
+    // A clean run shuts the task down, reports STOPPED, fires afterStop, and never notifies afterFailure.
+    verify(this.taskInstance).shutdownTask
+    assertEquals(SamzaContainerStatus.STOPPED, this.samzaContainer.getStatus())
+    verify(this.samzaContainerListener).beforeStart()
+    verify(this.samzaContainerListener).afterStart()
+    verify(this.samzaContainerListener).afterStop()
+    verify(this.samzaContainerListener, never()).afterFailure(any())
+    verify(this.runLoop).run()
+  }
+
   @Test
-  def testReadJobModel {
+  def testFailureDuringShutdown(): Unit = {
+    doNothing().when(this.runLoop).run() // run loop completes successfully
+    when(this.taskInstance.shutdownTask).thenThrow(new RuntimeException("Trigger a shutdown, please."))
+
+    this.samzaContainer.run
+    // The run loop finished, but the shutdown failure must report FAILED and afterFailure instead of afterStop.
+    verify(this.taskInstance).shutdownTask
+    assertEquals(SamzaContainerStatus.FAILED, this.samzaContainer.getStatus())
+    verify(this.samzaContainerListener).beforeStart()
+    verify(this.samzaContainerListener).afterStart()
+    verify(this.samzaContainerListener, never()).afterStop()
+    verify(this.samzaContainerListener).afterFailure(notNull(classOf[Exception]))
+    verify(this.runLoop).run()
+  }
+
+  @Test
+  def testStartStoresIncrementsCounter(): Unit = {
+    when(this.taskInstance.taskName).thenReturn(TASK_NAME)
+    val restoreGauge = mock[Gauge[Long]]
+    when(this.metrics.taskStoreRestorationMetrics).thenReturn(Map(TASK_NAME -> restoreGauge))
+    when(this.taskInstance.startStores).thenAnswer(new Answer[Void] {
+      override def answer(invocation: InvocationOnMock): Void = {
+        Thread.sleep(1) // makes startStores take >= 1 ms so the gauge check below can pass (assumes ms units)
+        null
+      }
+    })
+    this.samzaContainer.startStores
+    val restoreGaugeValueCaptor = ArgumentCaptor.forClass(classOf[Long])
+    verify(restoreGauge).set(restoreGaugeValueCaptor.capture())
+    assertTrue(restoreGaugeValueCaptor.getValue >= 1)
+  }
+
+  @Test
+  def testReadJobModel() {
     val config = new MapConfig(Map("a" -> "b").asJava)
     val offsets = new util.HashMap[SystemStreamPartition, String]()
     offsets.put(new SystemStreamPartition("system","stream", new Partition(0)), "1")
@@ -74,7 +198,7 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
   }
 
   @Test
-  def testReadJobModelWithTimeouts {
+  def testReadJobModelWithTimeouts() {
     val config = new MapConfig(Map("a" -> "b").asJava)
     val offsets = new util.HashMap[SystemStreamPartition, String]()
     offsets.put(new SystemStreamPartition("system","stream", new Partition(0)), "1")
@@ -101,551 +225,7 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
   }
 
   @Test
-  def testChangelogPartitions {
-    val config = new MapConfig(Map("a" -> "b").asJava)
-    val offsets = new util.HashMap[SystemStreamPartition, String]()
-    offsets.put(new SystemStreamPartition("system", "stream", new Partition(0)), "1")
-    val tasksForContainer1 = Map(
-      new TaskName("t1") -> new TaskModel(new TaskName("t1"), offsets.keySet(), new Partition(0)),
-      new TaskName("t2") -> new TaskModel(new TaskName("t2"), offsets.keySet(), new Partition(1)))
-    val tasksForContainer2 = Map(
-      new TaskName("t3") -> new TaskModel(new TaskName("t3"), offsets.keySet(), new Partition(2)),
-      new TaskName("t4") -> new TaskModel(new TaskName("t4"), offsets.keySet(), new Partition(3)),
-      new TaskName("t5") -> new TaskModel(new TaskName("t6"), offsets.keySet(), new Partition(4)))
-    val containerModel1 = new ContainerModel("0", 0, tasksForContainer1)
-    val containerModel2 = new ContainerModel("1", 1, tasksForContainer2)
-    val containers = Map(
-      "0" -> containerModel1,
-      "1" -> containerModel2)
-    val jobModel = new JobModel(config, containers)
-    assertEquals(jobModel.maxChangeLogStreamPartitions, 5)
-  }
-
-  @Test
-  def testGetInputStreamMetadata {
-    val inputStreams = Set(
-      new SystemStreamPartition("test", "stream1", new Partition(0)),
-      new SystemStreamPartition("test", "stream1", new Partition(1)),
-      new SystemStreamPartition("test", "stream2", new Partition(0)),
-      new SystemStreamPartition("test", "stream2", new Partition(1)))
-    val systemAdmins = mock[SystemAdmins]
-    when(systemAdmins.getSystemAdmin("test")).thenReturn(new SinglePartitionWithoutOffsetsSystemAdmin)
-    val metadata = new StreamMetadataCache(systemAdmins).getStreamMetadata(inputStreams.map(_.getSystemStream))
-    assertNotNull(metadata)
-    assertEquals(2, metadata.size)
-    val stream1Metadata = metadata(new SystemStream("test", "stream1"))
-    val stream2Metadata = metadata(new SystemStream("test", "stream2"))
-    assertNotNull(stream1Metadata)
-    assertNotNull(stream2Metadata)
-    assertEquals("stream1", stream1Metadata.getStreamName)
-    assertEquals("stream2", stream2Metadata.getStreamName)
-  }
-
-  @Test
-  def testExceptionInTaskInitShutsDownTask {
-    val task = new StreamTask with InitableTask with ClosableTask {
-      var wasShutdown = false
-
-      def init(config: Config, context: TaskContext) {
-        throw new Exception("Trigger a shutdown, please.")
-      }
-
-      def process(envelope: IncomingMessageEnvelope, collector: MessageCollector, coordinator: TaskCoordinator) {
-      }
-
-      def close {
-        wasShutdown = true
-      }
-    }
-    val config = new MapConfig
-    val taskName = new TaskName("taskName")
-    val systemAdmins = new SystemAdmins(config)
-    val consumerMultiplexer = new SystemConsumers(
-      new RoundRobinChooser,
-      Map[String, SystemConsumer]())
-    val producerMultiplexer = new SystemProducers(
-      Map[String, SystemProducer](),
-      new SerdeManager)
-    val collector = new TaskInstanceCollector(producerMultiplexer)
-    val containerContext = new SamzaContainerContext("0", config, Set[TaskName](taskName), new MetricsRegistryMap)
-    val taskInstance: TaskInstance = new TaskInstance(
-      task,
-      taskName,
-      config,
-      new TaskInstanceMetrics,
-      null,
-      consumerMultiplexer,
-      collector,
-      containerContext
-    )
-    val runLoop = new RunLoop(
-      taskInstances = Map(taskName -> taskInstance),
-      consumerMultiplexer = consumerMultiplexer,
-      metrics = new SamzaContainerMetrics,
-      maxThrottlingDelayMs = TimeUnit.SECONDS.toMillis(1))
-    @volatile var onContainerFailedCalled = false
-    @volatile var onContainerStopCalled = false
-    @volatile var onContainerStartCalled = false
-    @volatile var onContainerFailedThrowable: Throwable = null
-    @volatile var onContainerBeforeStartCalled = false
-
-    val container = new SamzaContainer(
-      containerContext = containerContext,
-      taskInstances = Map(taskName -> taskInstance),
-      runLoop = runLoop,
-      systemAdmins = systemAdmins,
-      consumerMultiplexer = consumerMultiplexer,
-      producerMultiplexer = producerMultiplexer,
-      metrics = new SamzaContainerMetrics)
-
-    val containerListener = new SamzaContainerListener {
-      override def afterFailure(t: Throwable): Unit = {
-        onContainerFailedCalled = true
-        onContainerFailedThrowable = t
-      }
-
-      override def afterStop(): Unit = {
-        onContainerStopCalled = true
-      }
-
-      override def afterStart(): Unit = {
-        onContainerStartCalled = true
-      }
-
-      override def beforeStart(): Unit = {
-        onContainerBeforeStartCalled = true
-      }
-
-    }
-    container.setContainerListener(containerListener)
-
-    container.run
-    assertTrue(task.wasShutdown)
-    assertTrue(onContainerBeforeStartCalled)
-    assertFalse(onContainerStartCalled)
-    assertFalse(onContainerStopCalled)
-
-    assertTrue(onContainerFailedCalled)
-    assertNotNull(onContainerFailedThrowable)
-  }
-
-  // Exception in Runloop should cause SamzaContainer to transition to FAILED status, shutdown the components and then,
-  // invoke the callback
-  @Test
-  def testExceptionInTaskProcessRunLoop() {
-    val task = new StreamTask with InitableTask with ClosableTask {
-      var wasShutdown = false
-
-      def init(config: Config, context: TaskContext) {
-      }
-
-      def process(envelope: IncomingMessageEnvelope, collector: MessageCollector, coordinator: TaskCoordinator) {
-        throw new Exception("Trigger a shutdown, please.")
-      }
-
-      def close {
-        wasShutdown = true
-      }
-    }
-    val config = new MapConfig
-    val taskName = new TaskName("taskName")
-    val systemAdmins = new SystemAdmins(config)
-    val consumerMultiplexer = new SystemConsumers(
-      new RoundRobinChooser,
-      Map[String, SystemConsumer]())
-    val producerMultiplexer = new SystemProducers(
-      Map[String, SystemProducer](),
-      new SerdeManager)
-    val collector = new TaskInstanceCollector(producerMultiplexer)
-    val containerContext = new SamzaContainerContext("0", config, Set[TaskName](taskName), new MetricsRegistryMap)
-    val taskInstance: TaskInstance = new TaskInstance(
-      task,
-      taskName,
-      config,
-      new TaskInstanceMetrics,
-      null,
-      consumerMultiplexer,
-      collector,
-      containerContext
-    )
-
-    @volatile var onContainerFailedCalled = false
-    @volatile var onContainerStopCalled = false
-    @volatile var onContainerStartCalled = false
-    @volatile var onContainerFailedThrowable: Throwable = null
-    @volatile var onContainerBeforeStartCalled = false
-
-    val mockRunLoop = mock[RunLoop]
-    when(mockRunLoop.run).thenThrow(new RuntimeException("Trigger a shutdown, please."))
-
-    val container = new SamzaContainer(
-      containerContext = containerContext,
-      taskInstances = Map(taskName -> taskInstance),
-      runLoop = mockRunLoop,
-      systemAdmins = systemAdmins,
-      consumerMultiplexer = consumerMultiplexer,
-      producerMultiplexer = producerMultiplexer,
-      metrics = new SamzaContainerMetrics)
-    val containerListener = new SamzaContainerListener {
-      override def afterFailure(t: Throwable): Unit = {
-        onContainerFailedCalled = true
-        onContainerFailedThrowable = t
-      }
-
-      override def afterStop(): Unit = {
-        onContainerStopCalled = true
-      }
-
-      override def afterStart(): Unit = {
-        onContainerStartCalled = true
-      }
-
-      /**
-        * Method invoked before the {@link org.apache.samza.container.SamzaContainer} is started
-        */
-      override def beforeStart(): Unit = {
-        onContainerBeforeStartCalled = true
-      }
-    }
-    container.setContainerListener(containerListener)
-
-    container.run
-    assertTrue(task.wasShutdown)
-    assertTrue(onContainerBeforeStartCalled)
-    assertTrue(onContainerStartCalled)
-
-    assertFalse(onContainerStopCalled)
-
-    assertTrue(onContainerFailedCalled)
-    assertNotNull(onContainerFailedThrowable)
-
-    assertEquals(SamzaContainerStatus.FAILED, container.getStatus())
-  }
-
-  @Test
-  def testErrorInTaskInitShutsDownTask() {
-    val task = new StreamTask with InitableTask with ClosableTask {
-      var wasShutdown = false
-
-      def init(config: Config, context: TaskContext) {
-        throw new NoSuchMethodError("Trigger a shutdown, please.")
-      }
-
-      def process(envelope: IncomingMessageEnvelope, collector: MessageCollector, coordinator: TaskCoordinator) {
-      }
-
-      def close {
-        wasShutdown = true
-      }
-    }
-    val config = new MapConfig
-    val taskName = new TaskName("taskName")
-    val systemAdmins = new SystemAdmins(config)
-    val consumerMultiplexer = new SystemConsumers(
-      new RoundRobinChooser,
-      Map[String, SystemConsumer]())
-    val producerMultiplexer = new SystemProducers(
-      Map[String, SystemProducer](),
-      new SerdeManager)
-    val collector = new TaskInstanceCollector(producerMultiplexer)
-    val containerContext = new SamzaContainerContext("0", config, Set[TaskName](taskName), new MetricsRegistryMap)
-    val taskInstance: TaskInstance = new TaskInstance(
-      task,
-      taskName,
-      config,
-      new TaskInstanceMetrics,
-      null,
-      consumerMultiplexer,
-      collector,
-      containerContext
-    )
-    val runLoop = new RunLoop(
-      taskInstances = Map(taskName -> taskInstance),
-      consumerMultiplexer = consumerMultiplexer,
-      metrics = new SamzaContainerMetrics,
-      maxThrottlingDelayMs = TimeUnit.SECONDS.toMillis(1))
-    @volatile var onContainerFailedCalled = false
-    @volatile var onContainerStopCalled = false
-    @volatile var onContainerStartCalled = false
-    @volatile var onContainerFailedThrowable: Throwable = null
-    @volatile var onContainerBeforeStartCalled = false
-
-    val container = new SamzaContainer(
-      containerContext = containerContext,
-      taskInstances = Map(taskName -> taskInstance),
-      runLoop = runLoop,
-      systemAdmins = systemAdmins,
-      consumerMultiplexer = consumerMultiplexer,
-      producerMultiplexer = producerMultiplexer,
-      metrics = new SamzaContainerMetrics)
-    val containerListener = new SamzaContainerListener {
-      override def afterFailure(t: Throwable): Unit = {
-        onContainerFailedCalled = true
-        onContainerFailedThrowable = t
-      }
-
-      override def afterStop(): Unit = {
-        onContainerStopCalled = true
-      }
-
-      override def afterStart(): Unit = {
-        onContainerStartCalled = true
-      }
-
-      /**
-        * Method invoked before the {@link org.apache.samza.container.SamzaContainer} is started
-        */
-      override def beforeStart(): Unit = {
-        onContainerBeforeStartCalled = true
-      }
-    }
-    container.setContainerListener(containerListener)
-
-    container.run
-
-    assertTrue(task.wasShutdown)
-    assertTrue(onContainerBeforeStartCalled)
-    assertFalse(onContainerStopCalled)
-    assertFalse(onContainerStartCalled)
-
-    assertTrue(onContainerFailedCalled)
-    assertNotNull(onContainerFailedThrowable)
-  }
-
-  @Test
-  def testRunloopShutdownIsClean(): Unit = {
-    val task = new StreamTask with InitableTask with ClosableTask {
-      var wasShutdown = false
-
-      def init(config: Config, context: TaskContext) {
-      }
-
-      def process(envelope: IncomingMessageEnvelope, collector: MessageCollector, coordinator: TaskCoordinator) {
-      }
-
-      def close {
-        wasShutdown = true
-      }
-    }
-    val config = new MapConfig
-    val taskName = new TaskName("taskName")
-    val systemAdmins = new SystemAdmins(config)
-    val consumerMultiplexer = new SystemConsumers(
-      new RoundRobinChooser,
-      Map[String, SystemConsumer]())
-    val producerMultiplexer = new SystemProducers(
-      Map[String, SystemProducer](),
-      new SerdeManager)
-    val collector = new TaskInstanceCollector(producerMultiplexer)
-    val containerContext = new SamzaContainerContext("0", config, Set[TaskName](taskName), new MetricsRegistryMap)
-    val taskInstance: TaskInstance = new TaskInstance(
-      task,
-      taskName,
-      config,
-      new TaskInstanceMetrics,
-      null,
-      consumerMultiplexer,
-      collector,
-      containerContext
-    )
-
-    @volatile var onContainerFailedCalled = false
-    @volatile var onContainerStopCalled = false
-    @volatile var onContainerStartCalled = false
-    @volatile var onContainerFailedThrowable: Throwable = null
-    @volatile var onContainerBeforeStartCalled = false
-
-    val mockRunLoop = mock[RunLoop]
-    when(mockRunLoop.run).thenAnswer(new Answer[Unit] {
-      override def answer(invocation: InvocationOnMock): Unit = {
-        Thread.sleep(100)
-      }
-    })
-
-    val container = new SamzaContainer(
-      containerContext = containerContext,
-      taskInstances = Map(taskName -> taskInstance),
-      runLoop = mockRunLoop,
-      systemAdmins = systemAdmins,
-      consumerMultiplexer = consumerMultiplexer,
-      producerMultiplexer = producerMultiplexer,
-      metrics = new SamzaContainerMetrics)
-      val containerListener = new SamzaContainerListener {
-        override def afterFailure(t: Throwable): Unit = {
-          onContainerFailedCalled = true
-          onContainerFailedThrowable = t
-        }
-
-        override def afterStop(): Unit = {
-          onContainerStopCalled = true
-        }
-
-        override def afterStart(): Unit = {
-          onContainerStartCalled = true
-        }
-
-        /**
-          * Method invoked before the {@link org.apache.samza.container.SamzaContainer} is started
-          */
-        override def beforeStart(): Unit = {
-          onContainerBeforeStartCalled = true
-        }
-      }
-    container.setContainerListener(containerListener)
-
-    container.run
-    assertTrue(onContainerBeforeStartCalled)
-    assertFalse(onContainerFailedCalled)
-    assertTrue(onContainerStartCalled)
-    assertTrue(onContainerStopCalled)
-  }
-
-  @Test
-  def testFailureDuringShutdown: Unit = {
-    val task = new StreamTask with InitableTask with ClosableTask {
-      def init(config: Config, context: TaskContext) {
-      }
-
-      def process(envelope: IncomingMessageEnvelope, collector: MessageCollector, coordinator: TaskCoordinator) {
-
-      }
-
-      def close {
-        throw new Exception("Exception during shutdown, please.")
-      }
-    }
-    val config = new MapConfig
-    val taskName = new TaskName("taskName")
-    val systemAdmins = new SystemAdmins(config)
-    val consumerMultiplexer = new SystemConsumers(
-      new RoundRobinChooser,
-      Map[String, SystemConsumer]())
-    val producerMultiplexer = new SystemProducers(
-      Map[String, SystemProducer](),
-      new SerdeManager)
-    val collector = new TaskInstanceCollector(producerMultiplexer)
-    val containerContext = new SamzaContainerContext("0", config, Set[TaskName](taskName), new MetricsRegistryMap)
-    val taskInstance: TaskInstance = new TaskInstance(
-      task,
-      taskName,
-      config,
-      new TaskInstanceMetrics,
-      null,
-      consumerMultiplexer,
-      collector,
-      containerContext
-    )
-
-    @volatile var onContainerFailedCalled = false
-    @volatile var onContainerStopCalled = false
-    @volatile var onContainerStartCalled = false
-    @volatile var onContainerFailedThrowable: Throwable = null
-    @volatile var onContainerBeforeStartCalled = false
-
-    val mockRunLoop = mock[RunLoop]
-    when(mockRunLoop.run).thenAnswer(new Answer[Unit] {
-      override def answer(invocation: InvocationOnMock): Unit = {
-        Thread.sleep(100)
-      }
-    })
-
-    val container = new SamzaContainer(
-      containerContext = containerContext,
-      taskInstances = Map(taskName -> taskInstance),
-      runLoop = mockRunLoop,
-      systemAdmins = systemAdmins,
-      consumerMultiplexer = consumerMultiplexer,
-      producerMultiplexer = producerMultiplexer,
-      metrics = new SamzaContainerMetrics)
-
-    val containerListener = new SamzaContainerListener {
-        override def afterFailure(t: Throwable): Unit = {
-          onContainerFailedCalled = true
-          onContainerFailedThrowable = t
-        }
-
-        override def afterStop(): Unit = {
-          onContainerStopCalled = true
-        }
-
-        override def afterStart(): Unit = {
-          onContainerStartCalled = true
-        }
-
-      /**
-        * Method invoked before the {@link org.apache.samza.container.SamzaContainer} is started
-        */
-      override def beforeStart(): Unit = {
-        onContainerBeforeStartCalled = true
-      }
-    }
-    container.setContainerListener(containerListener)
-
-    container.run
-
-    assertTrue(onContainerBeforeStartCalled)
-    assertTrue(onContainerStartCalled)
-    assertTrue(onContainerFailedCalled)
-    assertFalse(onContainerStopCalled)
-  }
-
-  @Test
-  def testStartStoresIncrementsCounter {
-    val task = new StreamTask {
-      def process(envelope: IncomingMessageEnvelope, collector: MessageCollector, coordinator: TaskCoordinator) {
-      }
-    }
-    val config = new MapConfig
-    val taskName = new TaskName("taskName")
-    val systemAdmins = new SystemAdmins(config)
-    val consumerMultiplexer = new SystemConsumers(
-      new RoundRobinChooser,
-      Map[String, SystemConsumer]())
-    val producerMultiplexer = new SystemProducers(
-      Map[String, SystemProducer](),
-      new SerdeManager)
-    val collector = new TaskInstanceCollector(producerMultiplexer)
-    val containerContext = new SamzaContainerContext("0", config, Set[TaskName](taskName), new MetricsRegistryMap)
-    val mockTaskStorageManager = mock[TaskStorageManager]
-
-    when(mockTaskStorageManager.init).thenAnswer(new Answer[String] {
-      override def answer(invocation: InvocationOnMock): String = {
-        Thread.sleep(1)
-        ""
-      }
-    })
-
-    val taskInstance: TaskInstance = new TaskInstance(
-      task,
-      taskName,
-      config,
-      new TaskInstanceMetrics,
-      null,
-      consumerMultiplexer,
-      collector,
-      containerContext,
-      storageManager = mockTaskStorageManager
-    )
-    val containerMetrics = new SamzaContainerMetrics()
-    containerMetrics.addStoreRestorationGauge(taskName, "store")
-    val container = new SamzaContainer(
-      containerContext = containerContext,
-      taskInstances = Map(taskName -> taskInstance),
-      runLoop = null,
-      systemAdmins = systemAdmins,
-      consumerMultiplexer = consumerMultiplexer,
-      producerMultiplexer = producerMultiplexer,
-      metrics = containerMetrics)
-
-    container.startStores
-    assertNotNull(containerMetrics.taskStoreRestorationMetrics)
-    assertNotNull(containerMetrics.taskStoreRestorationMetrics.get(taskName))
-    assertTrue(containerMetrics.taskStoreRestorationMetrics.get(taskName).getValue >= 1)
-
-  }
-
-  @Test
-  def testGetChangelogSSPsForContainer() = {
+  def testGetChangelogSSPsForContainer() {
     val taskName0 = new TaskName("task0")
     val taskName1 = new TaskName("task1")
     val taskModel0 = new TaskModel(taskName0,
@@ -665,7 +245,7 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
   }
 
   @Test
-  def testGetChangelogSSPsForContainerNoChangelogs() = {
+  def testGetChangelogSSPsForContainerNoChangelogs() {
     val taskName0 = new TaskName("task0")
     val taskName1 = new TaskName("task1")
     val taskModel0 = new TaskModel(taskName0,
@@ -677,29 +257,18 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
     val containerModel = new ContainerModel("processorId", 0, Map(taskName0 -> taskModel0, taskName1 -> taskModel1))
     assertEquals(Set(), SamzaContainer.getChangelogSSPsForContainer(containerModel, Map()))
   }
-}
-
-class MockCheckpointManager extends CheckpointManager {
-  override def start() = {}
-  override def stop() = {}
 
-  override def register(taskName: TaskName): Unit = {}
+  class MockJobServlet(exceptionLimit: Int, jobModelRef: AtomicReference[JobModel]) extends JobServlet(jobModelRef) {
+    var exceptionCount = 0
 
-  override def readLastCheckpoint(taskName: TaskName): Checkpoint = { new Checkpoint(Map[SystemStreamPartition, String]().asJava) }
-
-  override def writeCheckpoint(taskName: TaskName, checkpoint: Checkpoint): Unit = { }
-}
-
-class MockJobServlet(exceptionLimit: Int, jobModelRef: AtomicReference[JobModel]) extends JobServlet(jobModelRef) {
-  var exceptionCount = 0
-
-  override protected def getObjectToWrite() = {
-    if (exceptionCount < exceptionLimit) {
-      exceptionCount += 1
-      throw new java.io.IOException("Throwing exception")
-    } else {
-      val jobModel = jobModelRef.get()
-      jobModel
+    override protected def getObjectToWrite(): JobModel = {
+      if (exceptionCount < exceptionLimit) {
+        exceptionCount += 1
+        throw new java.io.IOException("Throwing exception")
+      } else {
+        val jobModel = jobModelRef.get()
+        jobModel
+      }
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/19c6f4f6/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
index 1672191..b196131 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
@@ -20,429 +20,211 @@
 package org.apache.samza.container
 
 
-import java.util.concurrent.ConcurrentHashMap
-
 import org.apache.samza.Partition
 import org.apache.samza.checkpoint.{Checkpoint, OffsetManager}
-import org.apache.samza.config.{Config, MapConfig}
-import org.apache.samza.metrics.{Counter, Metric, MetricsRegistryMap}
-import org.apache.samza.serializers.SerdeManager
-import org.apache.samza.system.IncomingMessageEnvelope
-import org.apache.samza.system.SystemAdmin
-import org.apache.samza.system.SystemConsumer
-import org.apache.samza.system.SystemConsumers
-import org.apache.samza.system.SystemProducer
-import org.apache.samza.system.SystemProducers
-import org.apache.samza.system.SystemStream
-import org.apache.samza.system.SystemStreamMetadata
+import org.apache.samza.config.Config
+import org.apache.samza.metrics.Counter
 import org.apache.samza.storage.TaskStorageManager
-import org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata
-import org.apache.samza.system._
-import org.apache.samza.system.chooser.RoundRobinChooser
+import org.apache.samza.system.{IncomingMessageEnvelope, SystemAdmin, SystemConsumers, SystemStream, _}
 import org.apache.samza.task._
 import org.junit.Assert._
-import org.junit.Test
+import org.junit.{Before, Test}
 import org.mockito.Matchers._
-import org.mockito.Mockito
 import org.mockito.Mockito._
-import org.scalatest.Assertions.intercept
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.mockito.{Matchers, Mock, MockitoAnnotations}
+import org.scalatest.mockito.MockitoSugar
 
-import scala.collection.mutable.ListBuffer
 import scala.collection.JavaConverters._
 
-class TestTaskInstance {
-  @Test
-  def testOffsetsAreUpdatedOnProcess {
-    val task = new StreamTask {
-      def process(envelope: IncomingMessageEnvelope, collector: MessageCollector, coordinator: TaskCoordinator) {
-      }
-    }
-    val config = new MapConfig
-    val partition = new Partition(0)
-    val consumerMultiplexer = new SystemConsumers(
-      new RoundRobinChooser,
-      Map[String, SystemConsumer]())
-    val producerMultiplexer = new SystemProducers(
-      Map[String, SystemProducer](),
-      new SerdeManager)
-    val systemStream = new SystemStream("test-system", "test-stream")
-    val systemStreamPartition = new SystemStreamPartition(systemStream, partition)
-    val systemStreamPartitions = Set(systemStreamPartition)
-    // Pretend our last checkpointed (next) offset was 2.
-    val testSystemStreamMetadata = new SystemStreamMetadata(systemStream.getStream, Map(partition -> new SystemStreamPartitionMetadata("0", "1", "2")).asJava)
-    val offsetManager = OffsetManager(Map(systemStream -> testSystemStreamMetadata), config)
-    val taskName = new TaskName("taskName")
-    val collector = new TaskInstanceCollector(producerMultiplexer)
-    val containerContext = new SamzaContainerContext("0", config, Set(taskName).asJava, new MetricsRegistryMap)
-    val taskInstance: TaskInstance = new TaskInstance(
-      task,
-      taskName,
-      config,
-      new TaskInstanceMetrics,
-      null,
-      consumerMultiplexer,
-      collector,
-      containerContext,
-      offsetManager,
-      systemStreamPartitions = systemStreamPartitions)
-    // Pretend we got a message with offset 2 and next offset 3.
-    val coordinator = new ReadableCoordinator(taskName)
-    taskInstance.process(new IncomingMessageEnvelope(systemStreamPartition, "2", null, null), coordinator)
-    // Check to see if the offset manager has been properly updated with offset 3.
-    val lastProcessedOffset = offsetManager.getLastProcessedOffset(taskName, systemStreamPartition)
-    assertTrue(lastProcessedOffset.isDefined)
-    assertEquals("2", lastProcessedOffset.get)
-  }
-
-  /**
-   * Mock exception used to test exception counts metrics.
-   */
-  class TroublesomeException extends RuntimeException {
-  }
-
-  /**
-   * Mock exception used to test exception counts metrics.
-   */
-  class NonFatalException extends RuntimeException {
-  }
-
-  /**
-   * Mock exception used to test exception counts metrics.
-   */
-  class FatalException extends RuntimeException {
-  }
-
-  /**
-   * Task used to test exception counts metrics.
-   */
-  class TroublesomeTask extends StreamTask with WindowableTask {
-    def process(
-                 envelope: IncomingMessageEnvelope,
-                 collector: MessageCollector,
-                 coordinator: TaskCoordinator) {
-
-      envelope.getOffset().toInt match {
-        case offset if offset % 2 == 0 => throw new TroublesomeException
-        case _ => throw new NonFatalException
-      }
-    }
-
-    def window(collector: MessageCollector, coordinator: TaskCoordinator) {
-      throw new FatalException
-    }
+class TestTaskInstance extends MockitoSugar {
+  private val SYSTEM_NAME = "test-system"
+  private val TASK_NAME = new TaskName("taskName")
+  private val SYSTEM_STREAM_PARTITION =
+    new SystemStreamPartition(new SystemStream(SYSTEM_NAME, "test-stream"), new Partition(0))
+  private val SYSTEM_STREAM_PARTITIONS = Set(SYSTEM_STREAM_PARTITION)
+
+  @Mock
+  private var task: AllTask = null
+  @Mock
+  private var config: Config = null
+  @Mock
+  private var metrics: TaskInstanceMetrics = null
+  @Mock
+  private var systemAdmins: SystemAdmins = null
+  @Mock
+  private var systemAdmin: SystemAdmin = null
+  @Mock
+  private var consumerMultiplexer: SystemConsumers = null
+  @Mock
+  private var collector: TaskInstanceCollector = null
+  @Mock
+  private var containerContext: SamzaContainerContext = null
+  @Mock
+  private var offsetManager: OffsetManager = null
+  @Mock
+  private var taskStorageManager: TaskStorageManager = null
+  // not a mock; using MockTaskInstanceExceptionHandler
+  private var taskInstanceExceptionHandler: MockTaskInstanceExceptionHandler = null
+
+  private var taskInstance: TaskInstance = null
+
+  @Before
+  def setup(): Unit = {
+    MockitoAnnotations.initMocks(this)
+    // not using Mockito mock since Mockito doesn't work well with the call-by-name argument in maybeHandle
+    this.taskInstanceExceptionHandler = new MockTaskInstanceExceptionHandler
+    this.taskInstance = new TaskInstance(this.task,
+      TASK_NAME,
+      this.config,
+      this.metrics,
+      this.systemAdmins,
+      this.consumerMultiplexer,
+      this.collector,
+      this.containerContext,
+      this.offsetManager,
+      storageManager = this.taskStorageManager,
+      systemStreamPartitions = SYSTEM_STREAM_PARTITIONS,
+      exceptionHandler = this.taskInstanceExceptionHandler)
+    when(this.systemAdmins.getSystemAdmin(SYSTEM_NAME)).thenReturn(this.systemAdmin)
   }
 
-  /*
-   * Helper method used to retrieve the value of a counter from a group.
-   */
-  private def getCount(
-                        group: ConcurrentHashMap[String, Metric],
-                        name: String): Long = {
-    group.get("exception-ignored-" + name.toLowerCase).asInstanceOf[Counter].getCount
+  @Test
+  def testProcess() {
+    val processesCounter = mock[Counter]
+    when(this.metrics.processes).thenReturn(processesCounter)
+    val messagesActuallyProcessedCounter = mock[Counter]
+    when(this.metrics.messagesActuallyProcessed).thenReturn(messagesActuallyProcessedCounter)
+    when(this.offsetManager.getStartingOffset(TASK_NAME, SYSTEM_STREAM_PARTITION)).thenReturn(Some("0"))
+    val envelope = new IncomingMessageEnvelope(SYSTEM_STREAM_PARTITION, "0", null, null)
+    val coordinator = mock[ReadableCoordinator]
+    this.taskInstance.process(envelope, coordinator)
+    assertEquals(1, this.taskInstanceExceptionHandler.numTimesCalled)
+    verify(this.task).process(envelope, this.collector, coordinator)
+    verify(processesCounter).inc()
+    verify(messagesActuallyProcessedCounter).inc()
   }
 
-  /**
-   * Test task instance exception metrics with two ignored exceptions and one
-   * exception not ignored.
-   */
   @Test
-  def testExceptionCounts {
-    val task = new TroublesomeTask
-    val ignoredExceptions = classOf[TroublesomeException].getName + "," +
-      classOf[NonFatalException].getName
-    val config = new MapConfig(Map[String, String](
-      "task.ignored.exceptions" -> ignoredExceptions).asJava)
-
-    val partition = new Partition(0)
-    val consumerMultiplexer = new SystemConsumers(
-      new RoundRobinChooser,
-      Map[String, SystemConsumer]())
-    val producerMultiplexer = new SystemProducers(
-      Map[String, SystemProducer](),
-      new SerdeManager)
-    val systemStream = new SystemStream("test-system", "test-stream")
-    val systemStreamPartition = new SystemStreamPartition(systemStream, partition)
-    val systemStreamPartitions = Set(systemStreamPartition)
-    // Pretend our last checkpointed (next) offset was 2.
-    val testSystemStreamMetadata = new SystemStreamMetadata(systemStream.getStream, Map(partition -> new SystemStreamPartitionMetadata("0", "1", "2")).asJava)
-    val offsetManager = OffsetManager(Map(systemStream -> testSystemStreamMetadata), config)
-    val taskName = new TaskName("taskName")
-    val collector = new TaskInstanceCollector(producerMultiplexer)
-    val containerContext = new SamzaContainerContext("0", config, Set(taskName).asJava, new MetricsRegistryMap)
-
-    val registry = new MetricsRegistryMap
-    val taskMetrics = new TaskInstanceMetrics(registry = registry)
-    val taskInstance = new TaskInstance(
-      task,
-      taskName,
-      config,
-      taskMetrics,
-      null,
-      consumerMultiplexer,
-      collector,
-      containerContext,
-      offsetManager,
-      systemStreamPartitions = systemStreamPartitions,
-      exceptionHandler = TaskInstanceExceptionHandler(taskMetrics, config))
-
-    val coordinator = new ReadableCoordinator(taskName)
-    taskInstance.process(new IncomingMessageEnvelope(systemStreamPartition, "1", null, null), coordinator)
-    taskInstance.process(new IncomingMessageEnvelope(systemStreamPartition, "2", null, null), coordinator)
-    taskInstance.process(new IncomingMessageEnvelope(systemStreamPartition, "3", null, null), coordinator)
-
-    val group = registry.getGroup(taskMetrics.group)
-    assertEquals(1L, getCount(group, classOf[TroublesomeException].getName))
-    assertEquals(2L, getCount(group, classOf[NonFatalException].getName))
-
-    intercept[FatalException] {
-      taskInstance.window(coordinator)
-    }
-    assertFalse(group.contains(classOf[FatalException].getName.toLowerCase))
+  def testWindow() {
+    val windowsCounter = mock[Counter]
+    when(this.metrics.windows).thenReturn(windowsCounter)
+    val coordinator = mock[ReadableCoordinator]
+    this.taskInstance.window(coordinator)
+    assertEquals(1, this.taskInstanceExceptionHandler.numTimesCalled)
+    verify(this.task).window(this.collector, coordinator)
+    verify(windowsCounter).inc()
   }
 
-  /**
-   * Test task instance exception metrics with all exception ignored using a
-   * wildcard.
-   */
   @Test
-  def testIgnoreAllExceptions {
-    val task = new TroublesomeTask
-    val config = new MapConfig(Map[String, String](
-      "task.ignored.exceptions" -> "*").asJava)
-
-    val partition = new Partition(0)
-    val consumerMultiplexer = new SystemConsumers(
-      new RoundRobinChooser,
-      Map[String, SystemConsumer]())
-    val producerMultiplexer = new SystemProducers(
-      Map[String, SystemProducer](),
-      new SerdeManager)
-    val systemStream = new SystemStream("test-system", "test-stream")
-    val systemStreamPartition = new SystemStreamPartition(systemStream, partition)
-    val systemStreamPartitions = Set(systemStreamPartition)
-    // Pretend our last checkpointed (next) offset was 2.
-    val testSystemStreamMetadata = new SystemStreamMetadata(systemStream.getStream, Map(partition -> new SystemStreamPartitionMetadata("0", "1", "2")).asJava)
-    val offsetManager = OffsetManager(Map(systemStream -> testSystemStreamMetadata), config)
-    val taskName = new TaskName("taskName")
-    val collector = new TaskInstanceCollector(producerMultiplexer)
-    val containerContext = new SamzaContainerContext("0", config, Set(taskName).asJava, new MetricsRegistryMap)
-
-    val registry = new MetricsRegistryMap
-    val taskMetrics = new TaskInstanceMetrics(registry = registry)
-    val taskInstance = new TaskInstance(
-      task,
-      taskName,
-      config,
-      taskMetrics,
-      null,
-      consumerMultiplexer,
-      collector,
-      containerContext,
-      offsetManager,
-      systemStreamPartitions = systemStreamPartitions,
-      exceptionHandler = TaskInstanceExceptionHandler(taskMetrics, config))
-
-    val coordinator = new ReadableCoordinator(taskName)
-    taskInstance.process(new IncomingMessageEnvelope(systemStreamPartition, "1", null, null), coordinator)
-    taskInstance.process(new IncomingMessageEnvelope(systemStreamPartition, "2", null, null), coordinator)
-    taskInstance.process(new IncomingMessageEnvelope(systemStreamPartition, "3", null, null), coordinator)
-    taskInstance.window(coordinator)
-
-    val group = registry.getGroup(taskMetrics.group)
-    assertEquals(1L, getCount(group, classOf[TroublesomeException].getName))
-    assertEquals(2L, getCount(group, classOf[NonFatalException].getName))
-    assertEquals(1L, getCount(group, classOf[FatalException].getName))
+  def testOffsetsAreUpdatedOnProcess() {
+    when(this.metrics.processes).thenReturn(mock[Counter])
+    when(this.metrics.messagesActuallyProcessed).thenReturn(mock[Counter])
+    when(this.offsetManager.getStartingOffset(TASK_NAME, SYSTEM_STREAM_PARTITION)).thenReturn(Some("2"))
+    this.taskInstance.process(new IncomingMessageEnvelope(SYSTEM_STREAM_PARTITION, "4", null, null),
+      mock[ReadableCoordinator])
+    verify(this.offsetManager).update(TASK_NAME, SYSTEM_STREAM_PARTITION, "4")
   }
 
   /**
-   * Tests that the init() method of task can override the existing offset
-   * assignment.
+   * Tests that the init() method of task can override the existing offset assignment.
+   * This helps verify wiring for the task context (i.e. offset manager).
    */
   @Test
-  def testManualOffsetReset {
-
-    val partition0 = new SystemStreamPartition("system", "stream", new Partition(0))
-    val partition1 = new SystemStreamPartition("system", "stream", new Partition(1))
-
-    val task = new StreamTask with InitableTask {
-
-      override def init(config: Config, context: TaskContext): Unit = {
-
-        assertTrue("Can only update offsets for assigned partition",
-          context.getSystemStreamPartitions.contains(partition1))
-
-        context.setStartingOffset(partition1, "10")
+  def testManualOffsetReset() {
+    when(this.task.init(any(), any())).thenAnswer(new Answer[Void] {
+      override def answer(invocation: InvocationOnMock): Void = {
+        val taskContext = invocation.getArgumentAt(1, classOf[TaskContext])
+        taskContext.setStartingOffset(SYSTEM_STREAM_PARTITION, "10")
+        null
       }
-
-      override def process(envelope: IncomingMessageEnvelope, collector: MessageCollector, coordinator: TaskCoordinator): Unit = {}
-    }
-
-    val config = new MapConfig()
-    val chooser = new RoundRobinChooser()
-    val consumers = new SystemConsumers(chooser, consumers = Map.empty)
-    val producers = new SystemProducers(Map.empty, new SerdeManager())
-    val metrics = new TaskInstanceMetrics()
-    val taskName = new TaskName("Offset Reset Task 0")
-    val collector = new TaskInstanceCollector(producers)
-    val containerContext = new SamzaContainerContext("0", config, Set(taskName).asJava, new MetricsRegistryMap)
-
-    val offsetManager = new OffsetManager()
-
-    offsetManager.startingOffsets += taskName -> Map(partition0 -> "0", partition1 -> "0")
-
-    val taskInstance = new TaskInstance(
-      task,
-      taskName,
-      config,
-      metrics,
-      null,
-      consumers,
-      collector,
-      containerContext,
-      offsetManager,
-      systemStreamPartitions = Set(partition0, partition1))
-
+    })
     taskInstance.initTask
 
-    assertEquals(Some("0"), offsetManager.getStartingOffset(taskName, partition0))
-    assertEquals(Some("10"), offsetManager.getStartingOffset(taskName, partition1))
+    verify(this.offsetManager).setStartingOffset(TASK_NAME, SYSTEM_STREAM_PARTITION, "10")
+    verifyNoMoreInteractions(this.offsetManager)
   }
 
   @Test
-  def testIgnoreMessagesOlderThanStartingOffsets {
-    val partition0 = new SystemStreamPartition("system", "stream", new Partition(0))
-    val partition1 = new SystemStreamPartition("system", "stream", new Partition(1))
-    val config = new MapConfig()
-    val chooser = new RoundRobinChooser()
-    val consumers = new SystemConsumers(chooser, consumers = Map.empty)
-    val producers = new SystemProducers(Map.empty, new SerdeManager())
-    val metrics = new TaskInstanceMetrics()
-    val taskName = new TaskName("testing")
-    val collector = new TaskInstanceCollector(producers)
-    val containerContext = new SamzaContainerContext("0", config, Set(taskName).asJava, new MetricsRegistryMap)
-    val offsetManager = new OffsetManager()
-    offsetManager.startingOffsets += taskName -> Map(partition0 -> "0", partition1 -> "100")
-    val systemAdmins = Mockito.mock(classOf[SystemAdmins])
-    when(systemAdmins.getSystemAdmin("system")).thenReturn(new MockSystemAdmin)
-    var result = new ListBuffer[IncomingMessageEnvelope]
-
-    val task = new StreamTask {
-      def process(envelope: IncomingMessageEnvelope, collector: MessageCollector, coordinator: TaskCoordinator) {
-        result += envelope
+  def testIgnoreMessagesOlderThanStartingOffsets() {
+    val processesCounter = mock[Counter]
+    when(this.metrics.processes).thenReturn(processesCounter)
+    val messagesActuallyProcessedCounter = mock[Counter]
+    when(this.metrics.messagesActuallyProcessed).thenReturn(messagesActuallyProcessedCounter)
+    when(this.offsetManager.getStartingOffset(TASK_NAME, SYSTEM_STREAM_PARTITION)).thenReturn(Some("5"))
+    when(this.systemAdmin.offsetComparator(any(), any())).thenAnswer(new Answer[Integer] {
+      override def answer(invocation: InvocationOnMock): Integer = {
+        val offset1 = invocation.getArgumentAt(0, classOf[String])
+        val offset2 = invocation.getArgumentAt(1, classOf[String])
+        offset1.toLong.compareTo(offset2.toLong)
       }
-    }
-
-    val taskInstance = new TaskInstance(
-      task,
-      taskName,
-      config,
-      metrics,
-      systemAdmins,
-      consumers,
-      collector,
-      containerContext,
-      offsetManager,
-      systemStreamPartitions = Set(partition0, partition1))
-
-    val coordinator = new ReadableCoordinator(taskName)
-    val envelope1 = new IncomingMessageEnvelope(partition0, "1", null, null)
-    val envelope2 = new IncomingMessageEnvelope(partition0, "2", null, null)
-    val envelope3 = new IncomingMessageEnvelope(partition1, "1", null, null)
-    val envelope4 = new IncomingMessageEnvelope(partition1, "102", null, null)
-
-    taskInstance.process(envelope1, coordinator)
-    taskInstance.process(envelope2, coordinator)
-    taskInstance.process(envelope3, coordinator)
-    taskInstance.process(envelope4, coordinator)
-
-    val expected = List(envelope1, envelope2, envelope4)
-    assertEquals(expected, result.toList)
+    })
+    val oldEnvelope = new IncomingMessageEnvelope(SYSTEM_STREAM_PARTITION, "0", null, null)
+    val newEnvelope0 = new IncomingMessageEnvelope(SYSTEM_STREAM_PARTITION, "5", null, null)
+    val newEnvelope1 = new IncomingMessageEnvelope(SYSTEM_STREAM_PARTITION, "7", null, null)
+
+    this.taskInstance.process(oldEnvelope, mock[ReadableCoordinator])
+    this.taskInstance.process(newEnvelope0, mock[ReadableCoordinator])
+    this.taskInstance.process(newEnvelope1, mock[ReadableCoordinator])
+    verify(this.task).process(Matchers.eq(newEnvelope0), Matchers.eq(this.collector), any())
+    verify(this.task).process(Matchers.eq(newEnvelope1), Matchers.eq(this.collector), any())
+    verify(this.task, never()).process(Matchers.eq(oldEnvelope), any(), any())
+    verify(processesCounter, times(3)).inc()
+    verify(messagesActuallyProcessedCounter, times(2)).inc()
   }
 
   @Test
-  def testCommitOrder {
-    // Simple objects
-    val partition = new Partition(0)
-    val taskName = new TaskName("taskName")
-    val systemStream = new SystemStream("test-system", "test-stream")
-    val systemStreamPartition = new SystemStreamPartition(systemStream, partition)
-    val checkpoint = new Checkpoint(Map(systemStreamPartition -> "4").asJava)
-
-    // Mocks
-    val collector = Mockito.mock(classOf[TaskInstanceCollector])
-    val storageManager = Mockito.mock(classOf[TaskStorageManager])
-    val offsetManager = Mockito.mock(classOf[OffsetManager])
-    when(offsetManager.buildCheckpoint(any())).thenReturn(checkpoint)
-    val mockOrder = inOrder(offsetManager, collector, storageManager)
-
-    val taskInstance: TaskInstance = new TaskInstance(
-      Mockito.mock(classOf[StreamTask]),
-      taskName,
-      new MapConfig,
-      new TaskInstanceMetrics,
-      null,
-      Mockito.mock(classOf[SystemConsumers]),
-      collector,
-      Mockito.mock(classOf[SamzaContainerContext]),
-      offsetManager,
-      storageManager,
-      systemStreamPartitions = Set(systemStreamPartition))
+  def testCommitOrder() {
+    val commitsCounter = mock[Counter]
+    when(this.metrics.commits).thenReturn(commitsCounter)
+    val checkpoint = new Checkpoint(Map(SYSTEM_STREAM_PARTITION -> "4").asJava)
+    when(this.offsetManager.buildCheckpoint(TASK_NAME)).thenReturn(checkpoint)
 
     taskInstance.commit
 
+    val mockOrder = inOrder(this.offsetManager, this.collector, this.taskStorageManager)
+
     // We must first get a snapshot of the checkpoint so it doesn't change while we flush. SAMZA-1384
-    mockOrder.verify(offsetManager).buildCheckpoint(taskName)
+    mockOrder.verify(this.offsetManager).buildCheckpoint(TASK_NAME)
     // Producers must be flushed next and ideally the output would be flushed before the changelog
     // s.t. the changelog and checkpoints (state and inputs) are captured last
-    mockOrder.verify(collector).flush
+    mockOrder.verify(this.collector).flush
     // Local state is next, to ensure that the state (particularly the offset file) never points to a newer changelog
     // offset than what is reflected in the on disk state.
-    mockOrder.verify(storageManager).flush()
+    mockOrder.verify(this.taskStorageManager).flush()
     // Finally, checkpoint the inputs with the snapshotted checkpoint captured at the beginning of commit
-    mockOrder.verify(offsetManager).writeCheckpoint(taskName, checkpoint)
+    mockOrder.verify(offsetManager).writeCheckpoint(TASK_NAME, checkpoint)
+    verify(commitsCounter).inc()
   }
 
   @Test(expected = classOf[SystemProducerException])
-  def testProducerExceptionsIsPropagated {
-    // Simple objects
-    val partition = new Partition(0)
-    val taskName = new TaskName("taskName")
-    val systemStream = new SystemStream("test-system", "test-stream")
-    val systemStreamPartition = new SystemStreamPartition(systemStream, partition)
-
-    // Mocks
-    val collector = Mockito.mock(classOf[TaskInstanceCollector])
-    when(collector.flush).thenThrow(new SystemProducerException("Test"))
-    val storageManager = Mockito.mock(classOf[TaskStorageManager])
-    val offsetManager = Mockito.mock(classOf[OffsetManager])
-
-    val taskInstance: TaskInstance = new TaskInstance(
-      Mockito.mock(classOf[StreamTask]),
-      taskName,
-      new MapConfig,
-      new TaskInstanceMetrics,
-      null,
-      Mockito.mock(classOf[SystemConsumers]),
-      collector,
-      Mockito.mock(classOf[SamzaContainerContext]),
-      offsetManager,
-      storageManager,
-      systemStreamPartitions = Set(systemStreamPartition))
+  def testProducerExceptionsIsPropagated() {
+    when(this.metrics.commits).thenReturn(mock[Counter])
+    when(this.collector.flush).thenThrow(new SystemProducerException("systemProducerException"))
 
     try {
       taskInstance.commit // Should not swallow the SystemProducerException
     } finally {
-      Mockito.verify(offsetManager, times(0)).writeCheckpoint(any(classOf[TaskName]), any(classOf[Checkpoint]))
+      verify(offsetManager, never()).writeCheckpoint(any(), any())
     }
   }
 
-}
-
-class MockSystemAdmin extends SystemAdmin {
-  override def getOffsetsAfter(offsets: java.util.Map[SystemStreamPartition, String]) = { offsets }
-  override def getSystemStreamMetadata(streamNames: java.util.Set[String]) = null
+  /**
+    * Task type which has all task traits, which can be mocked.
+    */
+  trait AllTask extends StreamTask with InitableTask with WindowableTask {}
 
-  override def offsetComparator(offset1: String, offset2: String) = {
-    offset1.toLong compare offset2.toLong
+  /**
+    * Mock version of [TaskInstanceExceptionHandler] which just does a passthrough execution and keeps track of the
+    * number of times it is called. This is used to verify that the handler does get used to wrap the actual processing.
+    */
+  class MockTaskInstanceExceptionHandler extends TaskInstanceExceptionHandler {
+    var numTimesCalled = 0
+
+    override def maybeHandle(tryCodeBlock: => Unit): Unit = {
+      numTimesCalled += 1
+      tryCodeBlock
+    }
   }
-}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/samza/blob/19c6f4f6/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstanceExceptionHandler.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstanceExceptionHandler.scala b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstanceExceptionHandler.scala
new file mode 100644
index 0000000..ca06b2a
--- /dev/null
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstanceExceptionHandler.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container
+
+import com.google.common.collect.ImmutableMap
+import org.apache.samza.config.{Config, MapConfig, TaskConfig}
+import org.apache.samza.metrics.{Counter, MetricsHelper}
+import org.junit.{Before, Test}
+import org.mockito.Mockito._
+import org.mockito.{Mock, MockitoAnnotations}
+import org.scalatest.junit.AssertionsForJUnit
+import org.scalatest.mockito.MockitoSugar
+
+/**
+ * Unit tests for [[TaskInstanceExceptionHandler]]: whether an exception thrown by the wrapped code is swallowed or
+ * rethrown depending on [[TaskConfig.IGNORED_EXCEPTIONS]], and which "exception-ignored-" counters get incremented.
+ */
+class TestTaskInstanceExceptionHandler extends AssertionsForJUnit with MockitoSugar {
+  @Mock
+  private var metrics: MetricsHelper = null
+  @Mock
+  private var troublesomeExceptionCounter: Counter = null
+  @Mock
+  private var nonFatalExceptionCounter: Counter = null
+  @Mock
+  private var fatalExceptionCounter: Counter = null
+
+  /**
+   * Initializes the @Mock fields and stubs `metrics.newCounter` so that each test exception class gets its own
+   * counter mock under the name "exception-ignored-" + the exception's class name.
+   */
+  @Before
+  def setup() {
+    MockitoAnnotations.initMocks(this)
+    when(this.metrics.newCounter("exception-ignored-" + classOf[TroublesomeException].getName)).thenReturn(
+        this.troublesomeExceptionCounter)
+    when(this.metrics.newCounter("exception-ignored-" + classOf[NonFatalException].getName)).thenReturn(
+        this.nonFatalExceptionCounter)
+    when(this.metrics.newCounter("exception-ignored-" + classOf[FatalException].getName)).thenReturn(
+        this.fatalExceptionCounter)
+  }
+
+  /**
+   * Given that no exceptions are ignored, any exception should get propagated up.
+   */
+  @Test
+  def testHandleIgnoreNone() {
+    val handler = build(new MapConfig())
+    intercept[TroublesomeException] {
+      // maybeHandle takes a by-name (=> Unit) argument, so the code under test is passed as a plain block
+      handler.maybeHandle {
+        throw new TroublesomeException()
+      }
+    }
+    verifyZeroInteractions(this.metrics, this.troublesomeExceptionCounter, this.nonFatalExceptionCounter,
+        this.fatalExceptionCounter)
+  }
+
+  /**
+   * Given that some exceptions are ignored, the ignored exceptions should not be thrown and should increment the proper
+   * metrics, and any other exception should get propagated up.
+   */
+  @Test
+  def testHandleIgnoreSome() {
+    val config = new MapConfig(ImmutableMap.of(TaskConfig.IGNORED_EXCEPTIONS,
+        String.join(",", classOf[TroublesomeException].getName, classOf[NonFatalException].getName)))
+    val handler = build(config)
+    handler.maybeHandle {
+      throw new TroublesomeException()
+    }
+    handler.maybeHandle {
+      throw new NonFatalException()
+    }
+    intercept[FatalException] {
+      handler.maybeHandle {
+        throw new FatalException()
+      }
+    }
+    handler.maybeHandle {
+      throw new TroublesomeException()
+    }
+    verify(this.troublesomeExceptionCounter, times(2)).inc()
+    // double check that the counter gets cached for multiple occurrences of the same exception type
+    verify(this.metrics).newCounter("exception-ignored-" + classOf[TroublesomeException].getName)
+    verify(this.nonFatalExceptionCounter).inc()
+    verifyZeroInteractions(this.fatalExceptionCounter)
+  }
+
+  /**
+   * Given that all exceptions are ignored, no exceptions should be thrown and the proper metrics should be incremented.
+   */
+  @Test
+  def testHandleIgnoreAll() {
+    val config = new MapConfig(ImmutableMap.of(TaskConfig.IGNORED_EXCEPTIONS, "*"))
+    val handler = build(config)
+    handler.maybeHandle {
+      throw new TroublesomeException()
+    }
+    handler.maybeHandle {
+      throw new TroublesomeException()
+    }
+    handler.maybeHandle {
+      throw new NonFatalException()
+    }
+    handler.maybeHandle {
+      throw new FatalException()
+    }
+
+    verify(this.troublesomeExceptionCounter, times(2)).inc()
+    // double check that the counter gets cached for multiple occurrences of the same exception type
+    verify(this.metrics).newCounter("exception-ignored-" + classOf[TroublesomeException].getName)
+    verify(this.nonFatalExceptionCounter).inc()
+    verify(this.fatalExceptionCounter).inc()
+  }
+
+  /**
+   * Builds the handler under test from the mocked metrics and the given config.
+   */
+  private def build(config: Config): TaskInstanceExceptionHandler = {
+    TaskInstanceExceptionHandler.apply(this.metrics, config)
+  }
+
+  /**
+   * Mock exception used to test exception counts metrics; listed as ignored in testHandleIgnoreSome.
+   */
+  private class TroublesomeException extends RuntimeException {
+  }
+
+  /**
+   * Mock exception used to test exception counts metrics; listed as ignored in testHandleIgnoreSome.
+   */
+  private class NonFatalException extends RuntimeException {
+  }
+
+  /**
+   * Mock exception used to test exception counts metrics; not ignored in testHandleIgnoreSome, so it must propagate.
+   */
+  private class FatalException extends RuntimeException {
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/19c6f4f6/samza-core/src/test/scala/org/apache/samza/system/chooser/MockSystemAdmin.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/system/chooser/MockSystemAdmin.scala b/samza-core/src/test/scala/org/apache/samza/system/chooser/MockSystemAdmin.scala
new file mode 100644
index 0000000..288dd25
--- /dev/null
+++ b/samza-core/src/test/scala/org/apache/samza/system/chooser/MockSystemAdmin.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.system.chooser
+
+import org.apache.samza.system.{SystemAdmin, SystemStreamPartition}
+
+class MockSystemAdmin extends SystemAdmin {
+  /** Echoes the supplied offsets back unchanged. */
+  override def getOffsetsAfter(offsets: java.util.Map[SystemStreamPartition, String]) = offsets
+  override def getSystemStreamMetadata(streamNames: java.util.Set[String]) = null
+  /** Orders offsets numerically; both must parse as longs or NumberFormatException is thrown. */
+  override def offsetComparator(offset1: String, offset2: String) =
+    offset1.toLong.compare(offset2.toLong)
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/19c6f4f6/samza-core/src/test/scala/org/apache/samza/system/chooser/TestBootstrappingChooser.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/system/chooser/TestBootstrappingChooser.scala b/samza-core/src/test/scala/org/apache/samza/system/chooser/TestBootstrappingChooser.scala
index 1a99355..5116a51 100644
--- a/samza-core/src/test/scala/org/apache/samza/system/chooser/TestBootstrappingChooser.scala
+++ b/samza-core/src/test/scala/org/apache/samza/system/chooser/TestBootstrappingChooser.scala
@@ -23,7 +23,6 @@ import java.util.Arrays
 
 import org.apache.samza.system._
 import org.apache.samza.Partition
-import org.apache.samza.container.MockSystemAdmin
 import org.apache.samza.metrics.MetricsRegistryMap
 import org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata
 import org.junit.Assert._
@@ -301,4 +300,4 @@ object TestBootstrappingChooser {
       Array((wrapped: MessageChooser, bootstrapStreamMetadata: Map[SystemStream, SystemStreamMetadata]) =>
         new DefaultChooser(wrapped, bootstrapStreamMetadata = bootstrapStreamMetadata, registry = new MetricsRegistryMap(), systemAdmins = systemAdmins)))
   }
-}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/samza/blob/19c6f4f6/samza-core/src/test/scala/org/apache/samza/system/chooser/TestDefaultChooser.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/system/chooser/TestDefaultChooser.scala b/samza-core/src/test/scala/org/apache/samza/system/chooser/TestDefaultChooser.scala
index c4c702d..a4917d4 100644
--- a/samza-core/src/test/scala/org/apache/samza/system/chooser/TestDefaultChooser.scala
+++ b/samza-core/src/test/scala/org/apache/samza/system/chooser/TestDefaultChooser.scala
@@ -21,7 +21,6 @@ package org.apache.samza.system.chooser
 
 import org.apache.samza.Partition
 import org.apache.samza.config.{DefaultChooserConfig, MapConfig}
-import org.apache.samza.container.MockSystemAdmin
 import org.apache.samza.metrics.MetricsRegistryMap
 import org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata
 import org.apache.samza.system._