Posted to commits@samza.apache.org by ni...@apache.org on 2016/07/20 16:56:48 UTC

[2/3] samza git commit: SAMZA-863: Multithreading support in Samza
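
This patch lets a job implement the new AsyncStreamTask interface (processAsync with a TaskCallback) instead of StreamTask, and wires a container thread pool plus an AsyncRunLoop around it. Below is a minimal sketch of such a task, based on the processAsync and TaskCallback signatures that appear in the diff; the class name, the worker pool, and the work done inside run() are hypothetical and not part of the patch itself.

    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import org.apache.samza.system.IncomingMessageEnvelope;
    import org.apache.samza.task.AsyncStreamTask;
    import org.apache.samza.task.MessageCollector;
    import org.apache.samza.task.TaskCallback;
    import org.apache.samza.task.TaskCoordinator;

    public class MyAsyncTask implements AsyncStreamTask {
      // Hypothetical worker pool owned by the task; slow I/O runs here so the
      // run loop thread is never blocked.
      private final ExecutorService workers = Executors.newFixedThreadPool(4);

      @Override
      public void processAsync(final IncomingMessageEnvelope envelope, MessageCollector collector,
          TaskCoordinator coordinator, final TaskCallback callback) {
        workers.submit(new Runnable() {
          @Override
          public void run() {
            try {
              // ... do slow work with envelope.getMessage(), emit output via the collector ...
              callback.complete();
            } catch (Exception e) {
              callback.failure(e);
            }
          }
        });
      }
    }

Plain StreamTask implementations keep working: as the SamzaContainer changes below show, they are wrapped in an AsyncStreamTaskAdapter and executed on the container's task thread pool, and the run loop is created by RunLoopFactory as either a RunLoop or an AsyncRunLoop.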

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
index 18c0922..b8600d5 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
@@ -20,12 +20,19 @@
 package org.apache.samza.container
 
 import java.io.File
+import java.lang.Thread.UncaughtExceptionHandler
+import java.net.URL
+import java.net.UnknownHostException
 import java.nio.file.Path
 import java.util
-import java.lang.Thread.UncaughtExceptionHandler
-import java.net.{URL, UnknownHostException}
+import java.util.concurrent.ExecutorService
+import java.util.concurrent.Executors
+import java.util.concurrent.TimeUnit
+
 import org.apache.samza.SamzaException
-import org.apache.samza.checkpoint.{CheckpointManagerFactory, OffsetManager, OffsetManagerMetrics}
+import org.apache.samza.checkpoint.CheckpointManagerFactory
+import org.apache.samza.checkpoint.OffsetManager
+import org.apache.samza.checkpoint.OffsetManagerMetrics
 import org.apache.samza.config.JobConfig.Config2Job
 import org.apache.samza.config.MetricsConfig.Config2Metrics
 import org.apache.samza.config.SerializerConfig.Config2Serializer
@@ -34,18 +41,45 @@ import org.apache.samza.config.StorageConfig.Config2Storage
 import org.apache.samza.config.StreamConfig.Config2Stream
 import org.apache.samza.config.SystemConfig.Config2System
 import org.apache.samza.config.TaskConfig.Config2Task
+import org.apache.samza.container.disk.DiskQuotaPolicyFactory
+import org.apache.samza.container.disk.DiskSpaceMonitor
 import org.apache.samza.container.disk.DiskSpaceMonitor.Listener
-import org.apache.samza.container.disk.{NoThrottlingDiskQuotaPolicyFactory, DiskQuotaPolicyFactory, PollingScanDiskSpaceMonitor, DiskSpaceMonitor}
+import org.apache.samza.container.disk.NoThrottlingDiskQuotaPolicyFactory
+import org.apache.samza.container.disk.PollingScanDiskSpaceMonitor
 import org.apache.samza.coordinator.stream.CoordinatorStreamSystemFactory
-import org.apache.samza.job.model.{ContainerModel, JobModel}
-import org.apache.samza.metrics.{JmxServer, JvmMetrics, MetricsRegistryMap, MetricsReporter, MetricsReporterFactory}
-import org.apache.samza.serializers.{SerdeFactory, SerdeManager}
+import org.apache.samza.job.model.ContainerModel
+import org.apache.samza.job.model.JobModel
+import org.apache.samza.metrics.JmxServer
+import org.apache.samza.metrics.JvmMetrics
+import org.apache.samza.metrics.MetricsRegistryMap
+import org.apache.samza.metrics.MetricsReporter
+import org.apache.samza.metrics.MetricsReporterFactory
+import org.apache.samza.serializers.SerdeFactory
+import org.apache.samza.serializers.SerdeManager
 import org.apache.samza.serializers.model.SamzaObjectMapper
-import org.apache.samza.storage.{StorageEngineFactory, TaskStorageManager}
-import org.apache.samza.system.{StreamMetadataCache, SystemConsumers, SystemConsumersMetrics, SystemFactory, SystemProducers, SystemProducersMetrics, SystemStream, SystemStreamPartition}
-import org.apache.samza.system.chooser.{DefaultChooser, MessageChooserFactory, RoundRobinChooserFactory}
-import org.apache.samza.task.{StreamTask, TaskInstanceCollector}
-import org.apache.samza.util.{ThrottlingExecutor, ExponentialSleepStrategy, Logging, Util}
+import org.apache.samza.storage.StorageEngineFactory
+import org.apache.samza.storage.TaskStorageManager
+import org.apache.samza.system.StreamMetadataCache
+import org.apache.samza.system.SystemConsumers
+import org.apache.samza.system.SystemConsumersMetrics
+import org.apache.samza.system.SystemFactory
+import org.apache.samza.system.SystemProducers
+import org.apache.samza.system.SystemProducersMetrics
+import org.apache.samza.system.SystemStream
+import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.system.chooser.DefaultChooser
+import org.apache.samza.system.chooser.MessageChooserFactory
+import org.apache.samza.system.chooser.RoundRobinChooserFactory
+import org.apache.samza.task.AsyncRunLoop
+import org.apache.samza.task.AsyncStreamTask
+import org.apache.samza.task.AsyncStreamTaskAdapter
+import org.apache.samza.task.StreamTask
+import org.apache.samza.task.TaskInstanceCollector
+import org.apache.samza.util.ExponentialSleepStrategy
+import org.apache.samza.util.Logging
+import org.apache.samza.util.ThrottlingExecutor
+import org.apache.samza.util.Util
+
 import scala.collection.JavaConversions._
 
 object SamzaContainer extends Logging {
@@ -164,6 +198,12 @@ object SamzaContainer extends Logging {
 
     info("Got input stream metadata: %s" format inputStreamMetadata)
 
+    val taskClassName = config
+      .getTaskClass
+      .getOrElse(throw new SamzaException("No task class defined in configuration."))
+
+    info("Got stream task class: %s" format taskClassName)
+
     val consumers = inputSystems
       .map(systemName => {
         val systemFactory = systemFactories(systemName)
@@ -181,6 +221,9 @@ object SamzaContainer extends Logging {
 
     info("Got system consumers: %s" format consumers.keys)
 
+    val isAsyncTask = classOf[AsyncStreamTask].isAssignableFrom(Class.forName(taskClassName))
+    info("%s is AsyncStreamTask" format taskClassName)
+
     val producers = systemFactories
       .map {
         case (systemName, systemFactory) =>
@@ -360,26 +403,22 @@ object SamzaContainer extends Logging {
 
     info("Got storage engines: %s" format storageEngineFactories.keys)
 
-    val taskClassName = config
-      .getTaskClass
-      .getOrElse(throw new SamzaException("No task class defined in configuration."))
-
-    info("Got stream task class: %s" format taskClassName)
-
-    val taskWindowMs = config.getWindowMs.getOrElse(-1L)
-
-    info("Got window milliseconds: %s" format taskWindowMs)
+    val singleThreadMode = config.getSingleThreadMode
+    info("Got single thread mode: " + singleThreadMode)
 
-    val taskCommitMs = config.getCommitMs.getOrElse(60000L)
-
-    info("Got commit milliseconds: %s" format taskCommitMs)
+    if (singleThreadMode && isAsyncTask) {
+      throw new SamzaException("AsyncStreamTask %s cannot run on single thread mode." format taskClassName)
+    }
 
-    val taskShutdownMs = config.getShutdownMs.getOrElse(5000L)
+    val threadPoolSize = config.getThreadPoolSize
+    info("Got thread pool size: " + threadPoolSize)
 
-    info("Got shutdown timeout milliseconds: %s" format taskShutdownMs)
+    val taskThreadPool = if (!singleThreadMode && threadPoolSize > 0)
+      Executors.newFixedThreadPool(threadPoolSize)
+    else
+      null
 
     // Wire up all task-instance-level (unshared) objects.
-
     val taskNames = containerModel
       .getTasks
       .values
@@ -395,12 +434,18 @@ object SamzaContainer extends Logging {
     val storeWatchPaths = new util.HashSet[Path]()
     storeWatchPaths.add(defaultStoreBaseDir.toPath)
 
-    val taskInstances: Map[TaskName, TaskInstance] = containerModel.getTasks.values.map(taskModel => {
+    val taskInstances: Map[TaskName, TaskInstance[_]] = containerModel.getTasks.values.map(taskModel => {
       debug("Setting up task instance: %s" format taskModel)
 
       val taskName = taskModel.getTaskName
 
-      val task = Util.getObj[StreamTask](taskClassName)
+      val taskObj = Class.forName(taskClassName).newInstance
+
+      val task = if (!singleThreadMode && !isAsyncTask)
+        // Wrap the StreamTask into an AsyncStreamTask with the built-in thread pool
+        new AsyncStreamTaskAdapter(taskObj.asInstanceOf[StreamTask], taskThreadPool)
+      else
+        taskObj
 
       val taskInstanceMetrics = new TaskInstanceMetrics("TaskName-%s" format taskName)
 
@@ -487,20 +532,22 @@ object SamzaContainer extends Logging {
 
       info("Retrieved SystemStreamPartitions " + systemStreamPartitions + " for " + taskName)
 
-      val taskInstance = new TaskInstance(
-        task = task,
-        taskName = taskName,
-        config = config,
-        metrics = taskInstanceMetrics,
-        systemAdmins = systemAdmins,
-        consumerMultiplexer = consumerMultiplexer,
-        collector = collector,
-        containerContext = containerContext,
-        offsetManager = offsetManager,
-        storageManager = storageManager,
-        reporters = reporters,
-        systemStreamPartitions = systemStreamPartitions,
-        exceptionHandler = TaskInstanceExceptionHandler(taskInstanceMetrics, config))
+      def createTaskInstance[T](task: T): TaskInstance[T] = new TaskInstance[T](
+          task = task,
+          taskName = taskName,
+          config = config,
+          metrics = taskInstanceMetrics,
+          systemAdmins = systemAdmins,
+          consumerMultiplexer = consumerMultiplexer,
+          collector = collector,
+          containerContext = containerContext,
+          offsetManager = offsetManager,
+          storageManager = storageManager,
+          reporters = reporters,
+          systemStreamPartitions = systemStreamPartitions,
+          exceptionHandler = TaskInstanceExceptionHandler(taskInstanceMetrics, config))
+
+      val taskInstance = createTaskInstance(task)
 
       (taskName, taskInstance)
     }).toMap
@@ -533,14 +580,13 @@ object SamzaContainer extends Logging {
       info(s"Disk quotas disabled because polling interval is not set ($DISK_POLL_INTERVAL_KEY)")
     }
 
-    val runLoop = new RunLoop(
-      taskInstances = taskInstances,
-      consumerMultiplexer = consumerMultiplexer,
-      metrics = samzaContainerMetrics,
-      windowMs = taskWindowMs,
-      commitMs = taskCommitMs,
-      shutdownMs = taskShutdownMs,
-      executor = executor)
+    val runLoop = RunLoopFactory.createRunLoop(
+      taskInstances,
+      consumerMultiplexer,
+      taskThreadPool,
+      executor,
+      samzaContainerMetrics,
+      config)
 
     info("Samza container setup complete.")
 
@@ -557,14 +603,15 @@ object SamzaContainer extends Logging {
       reporters = reporters,
       jvm = jvm,
       jmxServer = jmxServer,
-      diskSpaceMonitor = diskSpaceMonitor)
+      diskSpaceMonitor = diskSpaceMonitor,
+      taskThreadPool = taskThreadPool)
   }
 }
 
 class SamzaContainer(
   containerContext: SamzaContainerContext,
-  taskInstances: Map[TaskName, TaskInstance],
-  runLoop: RunLoop,
+  taskInstances: Map[TaskName, TaskInstance[_]],
+  runLoop: Runnable,
   consumerMultiplexer: SystemConsumers,
   producerMultiplexer: SystemProducers,
   metrics: SamzaContainerMetrics,
@@ -574,7 +621,10 @@ class SamzaContainer(
   localityManager: LocalityManager = null,
   securityManager: SecurityManager = null,
   reporters: Map[String, MetricsReporter] = Map(),
-  jvm: JvmMetrics = null) extends Runnable with Logging {
+  jvm: JvmMetrics = null,
+  taskThreadPool: ExecutorService = null) extends Runnable with Logging {
+
+  val shutdownMs = containerContext.config.getShutdownMs.getOrElse(5000L)
 
   def run {
     try {
@@ -591,6 +641,7 @@ class SamzaContainer(
       startSecurityManger
 
       info("Entering run loop.")
+      addShutdownHook
       runLoop.run
     } catch {
       case e: Exception =>
@@ -710,7 +761,7 @@ class SamzaContainer(
     consumerMultiplexer.start
   }
 
-  def startSecurityManger: Unit = {
+  def startSecurityManger {
     if (securityManager != null) {
       info("Starting security manager.")
 
@@ -718,6 +769,25 @@ class SamzaContainer(
     }
   }
 
+  def addShutdownHook {
+    val runLoopThread = Thread.currentThread()
+    Runtime.getRuntime().addShutdownHook(new Thread() {
+      override def run() = {
+        info("Shutting down, will wait up to %s ms" format shutdownMs)
+        runLoop match {
+          case runLoop: RunLoop => runLoop.shutdown
+          case asyncRunLoop: AsyncRunLoop => asyncRunLoop.shutdown()
+        }
+        runLoopThread.join(shutdownMs)
+        if (runLoopThread.isAlive) {
+          warn("Did not shut down within %s ms, exiting" format shutdownMs)
+        } else {
+          info("Shutdown complete")
+        }
+      }
+    })
+  }
+
   def shutdownConsumers {
     info("Shutting down consumer multiplexer.")
 
@@ -733,6 +803,19 @@ class SamzaContainer(
   def shutdownTask {
     info("Shutting down task instance stream tasks.")
 
+
+    if (taskThreadPool != null) {
+      info("Shutting down task thread pool")
+      try {
+        taskThreadPool.shutdown()
+        if (!taskThreadPool.awaitTermination(shutdownMs, TimeUnit.MILLISECONDS)) {
+          taskThreadPool.shutdownNow()
+        }
+      } catch {
+        case e: Exception => error(e.getMessage, e)
+      }
+    }
+
     taskInstances.values.foreach(_.shutdownTask)
   }
 

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala
index 2044ce0..e3891cf 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala
@@ -34,9 +34,11 @@ class SamzaContainerMetrics(
   val envelopes = newCounter("process-envelopes")
   val nullEnvelopes = newCounter("process-null-envelopes")
   val chooseNs = newTimer("choose-ns")
+  val chooserUpdateNs = newTimer("chooser-update-ns")
   val windowNs = newTimer("window-ns")
   val processNs = newTimer("process-ns")
   val commitNs = newTimer("commit-ns")
+  val blockNs = newTimer("block-ns")
   val utilization = newGauge("event-loop-utilization", 0.0F)
   val diskUsageBytes = newGauge("disk-usage-bytes", 0L)
   val diskQuotaBytes = newGauge("disk-quota-bytes", Long.MaxValue)

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
index d32a929..89f6857 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
@@ -19,36 +19,39 @@
 
 package org.apache.samza.container
 
+
 import org.apache.samza.SamzaException
 import org.apache.samza.checkpoint.OffsetManager
 import org.apache.samza.config.Config
-import org.apache.samza.config.TaskConfig.Config2Task
 import org.apache.samza.metrics.MetricsReporter
 import org.apache.samza.storage.TaskStorageManager
 import org.apache.samza.system.IncomingMessageEnvelope
-import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.system.SystemAdmin
 import org.apache.samza.system.SystemConsumers
-import org.apache.samza.task.TaskContext
+import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.task.AsyncStreamTask
 import org.apache.samza.task.ClosableTask
 import org.apache.samza.task.InitableTask
-import org.apache.samza.task.WindowableTask
-import org.apache.samza.task.StreamTask
 import org.apache.samza.task.ReadableCoordinator
+import org.apache.samza.task.StreamTask
+import org.apache.samza.task.TaskCallbackFactory
+import org.apache.samza.task.TaskContext
 import org.apache.samza.task.TaskInstanceCollector
+import org.apache.samza.task.WindowableTask
 import org.apache.samza.util.Logging
+
 import scala.collection.JavaConversions._
-import org.apache.samza.system.SystemAdmin
 
-class TaskInstance(
-  task: StreamTask,
+class TaskInstance[T](
+  task: T,
   val taskName: TaskName,
   config: Config,
-  metrics: TaskInstanceMetrics,
+  val metrics: TaskInstanceMetrics,
   systemAdmins: Map[String, SystemAdmin],
   consumerMultiplexer: SystemConsumers,
   collector: TaskInstanceCollector,
   containerContext: SamzaContainerContext,
-  offsetManager: OffsetManager = new OffsetManager,
+  val offsetManager: OffsetManager = new OffsetManager,
   storageManager: TaskStorageManager = null,
   reporters: Map[String, MetricsReporter] = Map(),
   val systemStreamPartitions: Set[SystemStreamPartition] = Set(),
@@ -56,6 +59,8 @@ class TaskInstance(
   val isInitableTask = task.isInstanceOf[InitableTask]
   val isWindowableTask = task.isInstanceOf[WindowableTask]
   val isClosableTask = task.isInstanceOf[ClosableTask]
+  val isAsyncTask = task.isInstanceOf[AsyncStreamTask]
+
   val context = new TaskContext {
     def getMetricsRegistry = metrics.registry
     def getSystemStreamPartitions = systemStreamPartitions
@@ -133,7 +138,7 @@ class TaskInstance(
     })
   }
 
-  def process(envelope: IncomingMessageEnvelope, coordinator: ReadableCoordinator) {
+  def process(envelope: IncomingMessageEnvelope, coordinator: ReadableCoordinator, callbackFactory: TaskCallbackFactory = null) {
     metrics.processes.inc
 
     if (!ssp2catchedupMapping.getOrElse(envelope.getSystemStreamPartition,
@@ -146,13 +151,20 @@ class TaskInstance(
 
       trace("Processing incoming message envelope for taskName and SSP: %s, %s" format (taskName, envelope.getSystemStreamPartition))
 
-      exceptionHandler.maybeHandle {
-        task.process(envelope, collector, coordinator)
-      }
+      if (isAsyncTask) {
+        exceptionHandler.maybeHandle {
+          val callback = callbackFactory.createCallback()
+          task.asInstanceOf[AsyncStreamTask].processAsync(envelope, collector, coordinator, callback)
+        }
+      } else {
+        exceptionHandler.maybeHandle {
+          task.asInstanceOf[StreamTask].process(envelope, collector, coordinator)
+        }
 
-      trace("Updating offset map for taskName, SSP and offset: %s, %s, %s" format (taskName, envelope.getSystemStreamPartition, envelope.getOffset))
+        trace("Updating offset map for taskName, SSP and offset: %s, %s, %s" format (taskName, envelope.getSystemStreamPartition, envelope.getOffset))
 
-      offsetManager.update(taskName, envelope.getSystemStreamPartition, envelope.getOffset)
+        offsetManager.update(taskName, envelope.getSystemStreamPartition, envelope.getOffset)
+      }
     }
   }
 

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
index 8b86388..7bedadf 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
@@ -35,6 +35,8 @@ class TaskInstanceMetrics(
   val sends = newCounter("send-calls")
   val flushes = newCounter("flush-calls")
   val messagesSent = newCounter("messages-sent")
+  val pendingMessages = newGauge("pending-messages", 0)
+  val messagesInFlight = newGauge("messages-in-flight", 0)
 
   def addOffsetGauge(systemStreamPartition: SystemStreamPartition, getValue: () => String) {
     newGauge("%s-%s-%d-offset" format (systemStreamPartition.getSystem, systemStreamPartition.getStream, systemStreamPartition.getPartition.getPartitionId), getValue)

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala b/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala
index d3bd9b7..ba38b5c 100644
--- a/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala
+++ b/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala
@@ -71,9 +71,12 @@ object JobModelManager extends Logging {
     coordinatorSystemConsumer.start
     debug("Bootstrapping coordinator system stream.")
     coordinatorSystemConsumer.bootstrap
+    val source = "Job-coordinator"
+    coordinatorSystemProducer.register(source)
+    info("Registering coordinator system stream producer.")
     val config = coordinatorSystemConsumer.getConfig
     info("Got config: %s" format config)
-    val changelogManager = new ChangelogPartitionManager(coordinatorSystemProducer, coordinatorSystemConsumer, "Job-coordinator")
+    val changelogManager = new ChangelogPartitionManager(coordinatorSystemProducer, coordinatorSystemConsumer, source)
     val localityManager = new LocalityManager(coordinatorSystemProducer, coordinatorSystemConsumer)
 
     val systemNames = getSystemNames(config)

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
index 2efe836..a8355b9 100644
--- a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
+++ b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
@@ -99,7 +99,7 @@ class SystemConsumers (
    * with no remaining unprocessed messages, the SystemConsumers will poll for
    * it within 50ms of its availability in the stream system.</p>
    */
-  pollIntervalMs: Int,
+  val pollIntervalMs: Int,
 
   /**
    * Clock can be used to inject a custom clock when mocking this class in
@@ -203,28 +203,31 @@ class SystemConsumers (
     }
   }
 
-  def choose: IncomingMessageEnvelope = {
+  def choose (updateChooser: Boolean = true): IncomingMessageEnvelope = {
     val envelopeFromChooser = chooser.choose
 
     updateTimer(metrics.deserializationNs) {
       if (envelopeFromChooser == null) {
-       trace("Chooser returned null.")
+        trace("Chooser returned null.")
 
-       metrics.choseNull.inc
+        metrics.choseNull.inc
 
-       // Sleep for a while so we don't poll in a tight loop.
-       timeout = noNewMessagesTimeout
+        // Sleep for a while so we don't poll in a tight loop.
+        timeout = noNewMessagesTimeout
       } else {
-       val systemStreamPartition = envelopeFromChooser.getSystemStreamPartition
+        val systemStreamPartition = envelopeFromChooser.getSystemStreamPartition
 
-       trace("Chooser returned an incoming message envelope: %s" format envelopeFromChooser)
+        trace("Chooser returned an incoming message envelope: %s" format envelopeFromChooser)
 
-       // Ok to give the chooser a new message from this stream.
-       timeout = 0
-       metrics.choseObject.inc
-       metrics.systemStreamMessagesChosen(envelopeFromChooser.getSystemStreamPartition).inc
+        // Ok to give the chooser a new message from this stream.
+        timeout = 0
+        metrics.choseObject.inc
+        metrics.systemStreamMessagesChosen(envelopeFromChooser.getSystemStreamPartition).inc
 
-       tryUpdate(systemStreamPartition)
+        if (updateChooser) {
+          trace("Update chooser for " + systemStreamPartition.getPartition)
+          tryUpdate(systemStreamPartition)
+        }
       }
     }
 
@@ -287,7 +290,7 @@ class SystemConsumers (
     }
   }
 
-  private def tryUpdate(ssp: SystemStreamPartition) {
+  def tryUpdate(ssp: SystemStreamPartition) {
     var updated = false
     try {
       updated = update(ssp)

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java b/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java
new file mode 100644
index 0000000..ca913de
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.samza.Partition;
+import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.config.Config;
+import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.container.SamzaContainerMetrics;
+import org.apache.samza.container.TaskInstance;
+import org.apache.samza.container.TaskInstanceExceptionHandler;
+import org.apache.samza.container.TaskInstanceMetrics;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemConsumers;
+import org.apache.samza.system.SystemStreamPartition;
+import org.junit.Before;
+import org.junit.Test;
+import scala.collection.JavaConversions;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class TestAsyncRunLoop {
+
+  Map<TaskName, TaskInstance<AsyncStreamTask>> tasks;
+  ExecutorService executor;
+  SystemConsumers consumerMultiplexer;
+  SamzaContainerMetrics containerMetrics;
+  OffsetManager offsetManager;
+  long windowMs;
+  long commitMs;
+  long callbackTimeoutMs;
+  int maxMessagesInFlight;
+  TaskCoordinator.RequestScope commitRequest;
+  TaskCoordinator.RequestScope shutdownRequest;
+
+  Partition p0 = new Partition(0);
+  Partition p1 = new Partition(1);
+  TaskName taskName0 = new TaskName(p0.toString());
+  TaskName taskName1 = new TaskName(p1.toString());
+  SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
+  SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
+  IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
+  IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+
+  TestTask task0;
+  TestTask task1;
+  TaskInstance<AsyncStreamTask> t0;
+  TaskInstance<AsyncStreamTask> t1;
+
+  AsyncRunLoop createRunLoop() {
+    return new AsyncRunLoop(tasks,
+        executor,
+        consumerMultiplexer,
+        maxMessagesInFlight,
+        windowMs,
+        commitMs,
+        callbackTimeoutMs,
+        containerMetrics);
+  }
+
+  TaskInstance<AsyncStreamTask> createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp) {
+    TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap());
+    scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConversions.asScalaSet(Collections.singleton(ssp)).toSet();
+    return new TaskInstance<AsyncStreamTask>(task, taskName, mock(Config.class), taskInstanceMetrics,
+        null, consumerMultiplexer, mock(TaskInstanceCollector.class), mock(SamzaContainerContext.class),
+        offsetManager, null, null, sspSet, new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>()));
+  }
+
+  ExecutorService callbackExecutor;
+  void triggerCallback(final TestTask task, final TaskCallback callback, final boolean success) {
+    callbackExecutor.submit(new Runnable() {
+      @Override
+      public void run() {
+        if (task.code != null) {
+          task.code.run(callback);
+        }
+
+        task.completed.incrementAndGet();
+
+        if (success) {
+          callback.complete();
+        } else {
+          callback.failure(new Exception("process failure"));
+        }
+      }
+    });
+  }
+
+  interface TestCode {
+    void run(TaskCallback callback);
+  }
+
+  class TestTask implements AsyncStreamTask, WindowableTask {
+    boolean shutdown = false;
+    boolean commit = false;
+    boolean success;
+    int processed = 0;
+    volatile int windowCount = 0;
+
+    AtomicInteger completed = new AtomicInteger(0);
+    TestCode code = null;
+
+    TestTask(boolean success, boolean commit, boolean shutdown) {
+      this.success = success;
+      this.shutdown = shutdown;
+      this.commit = commit;
+    }
+
+    @Override
+    public void processAsync(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator,
+        TaskCallback callback) {
+
+      if (maxMessagesInFlight == 1) {
+        assertEquals(processed, completed.get());
+      }
+
+      processed++;
+
+      if (commit) {
+        coordinator.commit(commitRequest);
+      }
+
+      if (shutdown) {
+        coordinator.shutdown(shutdownRequest);
+      }
+      triggerCallback(this, callback, success);
+    }
+
+    @Override
+    public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
+      windowCount++;
+
+      if (shutdown && windowCount == 4) {
+        coordinator.shutdown(shutdownRequest);
+      }
+    }
+  }
+
+  @Before
+  public void setup() {
+    executor = null;
+    consumerMultiplexer = mock(SystemConsumers.class);
+    windowMs = -1;
+    commitMs = -1;
+    maxMessagesInFlight = 1;
+    containerMetrics = new SamzaContainerMetrics("container", new MetricsRegistryMap());
+    callbackExecutor = Executors.newFixedThreadPool(2);
+    offsetManager = mock(OffsetManager.class);
+    shutdownRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
+
+    when(consumerMultiplexer.pollIntervalMs()).thenReturn(1000000);
+
+    tasks = new HashMap<>();
+    task0 = new TestTask(true, true, false);
+    task1 = new TestTask(true, false, true);
+    t0 = createTaskInstance(task0, taskName0, ssp0);
+    t1 = createTaskInstance(task1, taskName1, ssp1);
+    tasks.put(taskName0, t0);
+    tasks.put(taskName1, t1);
+  }
+
+  @Test
+  public void testProcessMultipleTasks() throws Exception {
+    AsyncRunLoop runLoop = createRunLoop();
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    runLoop.run();
+
+    callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+    assertEquals(1, task0.processed);
+    assertEquals(1, task0.completed.get());
+    assertEquals(1, task1.processed);
+    assertEquals(1, task1.completed.get());
+    assertEquals(2L, containerMetrics.envelopes().getCount());
+    assertEquals(2L, containerMetrics.processes().getCount());
+  }
+
+  @Test
+  public void testProcessInOrder() throws Exception {
+    AsyncRunLoop runLoop = createRunLoop();
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null);
+    runLoop.run();
+
+    callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+    assertEquals(2, task0.processed);
+    assertEquals(2, task0.completed.get());
+    assertEquals(1, task1.processed);
+    assertEquals(1, task1.completed.get());
+    assertEquals(3L, containerMetrics.envelopes().getCount());
+    assertEquals(3L, containerMetrics.processes().getCount());
+  }
+
+  @Test
+  public void testProcessOutOfOrder() throws Exception {
+    maxMessagesInFlight = 2;
+
+    final CountDownLatch latch = new CountDownLatch(1);
+    task0.code = new TestCode() {
+      @Override
+      public void run(TaskCallback callback) {
+        IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).envelope;
+        if (envelope == envelope0) {
+          // processing of the first message waits until the second one is processed
+          try {
+            latch.await();
+          } catch (InterruptedException e) {
+            e.printStackTrace();
+          }
+        } else {
+          // the second envelope completes first
+          assertEquals(0, task0.completed.get());
+          latch.countDown();
+        }
+      }
+    };
+
+    AsyncRunLoop runLoop = createRunLoop();
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null);
+    runLoop.run();
+
+    callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+    assertEquals(2, task0.processed);
+    assertEquals(2, task0.completed.get());
+    assertEquals(1, task1.processed);
+    assertEquals(1, task1.completed.get());
+    assertEquals(3L, containerMetrics.envelopes().getCount());
+    assertEquals(3L, containerMetrics.processes().getCount());
+  }
+
+  @Test
+  public void testWindow() throws Exception {
+    windowMs = 1;
+
+    AsyncRunLoop runLoop = createRunLoop();
+    when(consumerMultiplexer.choose(false)).thenReturn(null);
+    runLoop.run();
+
+    callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+    assertEquals(4, task1.windowCount);
+  }
+
+  @Test
+  public void testCommitSingleTask() throws Exception {
+    commitRequest = TaskCoordinator.RequestScope.CURRENT_TASK;
+
+    AsyncRunLoop runLoop = createRunLoop();
+    // have a null message in between to make sure task0 finishes processing and invokes the commit
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(null).thenReturn(envelope1).thenReturn(null);
+    runLoop.run();
+
+    callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+    verify(offsetManager).checkpoint(taskName0);
+    verify(offsetManager, never()).checkpoint(taskName1);
+  }
+
+  @Test
+  public void testCommitAllTasks() throws Exception {
+    commitRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
+
+    AsyncRunLoop runLoop = createRunLoop();
+    // have a null message in between to make sure task0 finishes processing and invokes the commit
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(null).thenReturn(envelope1).thenReturn(null);
+    runLoop.run();
+
+    callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+    verify(offsetManager).checkpoint(taskName0);
+    verify(offsetManager).checkpoint(taskName1);
+  }
+
+  @Test
+  public void testShutdownOnConsensus() throws Exception {
+    shutdownRequest = TaskCoordinator.RequestScope.CURRENT_TASK;
+
+    tasks = new HashMap<>();
+    task0 = new TestTask(true, true, true);
+    task1 = new TestTask(true, false, true);
+    t0 = createTaskInstance(task0, taskName0, ssp0);
+    t1 = createTaskInstance(task1, taskName1, ssp1);
+    tasks.put(taskName0, t0);
+    tasks.put(taskName1, t1);
+
+    AsyncRunLoop runLoop = createRunLoop();
+    // consensus is reached after envelope1 is processed.
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    runLoop.run();
+
+    callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+    assertEquals(1, task0.processed);
+    assertEquals(1, task0.completed.get());
+    assertEquals(1, task1.processed);
+    assertEquals(1, task1.completed.get());
+    assertEquals(2L, containerMetrics.envelopes().getCount());
+    assertEquals(2L, containerMetrics.processes().getCount());
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/java/org/apache/samza/task/TestAsyncStreamAdapter.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestAsyncStreamAdapter.java b/samza-core/src/test/java/org/apache/samza/task/TestAsyncStreamAdapter.java
new file mode 100644
index 0000000..99e1e18
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/task/TestAsyncStreamAdapter.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import org.apache.samza.config.Config;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+
+
+public class TestAsyncStreamAdapter {
+  TestStreamTask task;
+  AsyncStreamTaskAdapter taskAdaptor;
+  Exception e;
+  IncomingMessageEnvelope envelope;
+
+  class TestCallbackListener implements TaskCallbackListener {
+    boolean callbackComplete = false;
+    boolean callbackFailure = false;
+
+    @Override
+    public void onComplete(TaskCallback callback) {
+      callbackComplete = true;
+    }
+
+    @Override
+    public void onFailure(TaskCallback callback, Throwable t) {
+      callbackFailure = true;
+    }
+  }
+
+  class TestStreamTask implements StreamTask, InitableTask, ClosableTask, WindowableTask {
+    boolean inited = false;
+    boolean closed = false;
+    boolean processed = false;
+    boolean windowed = false;
+
+    @Override
+    public void close() throws Exception {
+      closed = true;
+    }
+
+    @Override
+    public void init(Config config, TaskContext context) throws Exception {
+      inited = true;
+    }
+
+    @Override
+    public void process(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator) throws Exception {
+      processed = true;
+      if (e != null) {
+        throw e;
+      }
+    }
+
+    @Override
+    public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
+      windowed = true;
+    }
+  }
+
+  @Before
+  public void setup() {
+    task = new TestStreamTask();
+    e = null;
+    envelope = mock(IncomingMessageEnvelope.class);
+  }
+
+  @Test
+  public void testAdapterWithoutThreadPool() throws Exception {
+    taskAdaptor = new AsyncStreamTaskAdapter(task, null);
+    TestCallbackListener listener = new TestCallbackListener();
+    TaskCallback callback = new TaskCallbackImpl(listener, null, envelope, null, 0L);
+
+    taskAdaptor.init(null, null);
+    assertTrue(task.inited);
+
+    taskAdaptor.processAsync(null, null, null, callback);
+    assertTrue(task.processed);
+    assertTrue(listener.callbackComplete);
+
+    e = new Exception("dummy exception");
+    taskAdaptor.processAsync(null, null, null, callback);
+    assertTrue(listener.callbackFailure);
+
+    taskAdaptor.window(null, null);
+    assertTrue(task.windowed);
+
+    taskAdaptor.close();
+    assertTrue(task.closed);
+  }
+
+  @Test
+  public void testAdapterWithThreadPool() throws Exception {
+    TestCallbackListener listener1 = new TestCallbackListener();
+    TaskCallback callback1 = new TaskCallbackImpl(listener1, null, envelope, null, 0L);
+
+    TestCallbackListener listener2 = new TestCallbackListener();
+    TaskCallback callback2 = new TaskCallbackImpl(listener2, null, envelope, null, 1L);
+
+    ExecutorService executor = Executors.newFixedThreadPool(2);
+    taskAdaptor = new AsyncStreamTaskAdapter(task, executor);
+    taskAdaptor.processAsync(null, null, null, callback1);
+    taskAdaptor.processAsync(null, null, null, callback2);
+
+    executor.awaitTermination(1, TimeUnit.SECONDS);
+    assertTrue(listener1.callbackComplete);
+    assertTrue(listener2.callbackComplete);
+
+    e = new Exception("dummy exception");
+    taskAdaptor.processAsync(null, null, null, callback1);
+    taskAdaptor.processAsync(null, null, null, callback2);
+
+    executor.awaitTermination(1, TimeUnit.SECONDS);
+    assertTrue(listener1.callbackFailure);
+    assertTrue(listener2.callbackFailure);
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/java/org/apache/samza/task/TestCoordinatorRequests.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestCoordinatorRequests.java b/samza-core/src/test/java/org/apache/samza/task/TestCoordinatorRequests.java
new file mode 100644
index 0000000..d9c68d7
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/task/TestCoordinatorRequests.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import java.util.HashSet;
+import java.util.Set;
+import org.apache.samza.container.TaskName;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+public class TestCoordinatorRequests {
+  CoordinatorRequests coordinatorRequests;
+  TaskName taskA = new TaskName("a");
+  TaskName taskB = new TaskName("b");
+  TaskName taskC = new TaskName("c");
+
+
+  @Before
+  public void setup() {
+    Set<TaskName> taskNames = new HashSet<>();
+    taskNames.add(taskA);
+    taskNames.add(taskB);
+    taskNames.add(taskC);
+
+    coordinatorRequests = new CoordinatorRequests(taskNames);
+  }
+
+  @Test
+  public void testUpdateCommit() {
+    ReadableCoordinator coordinator = new ReadableCoordinator(taskA);
+    coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+    coordinatorRequests.update(coordinator);
+    assertTrue(coordinatorRequests.commitRequests().contains(taskA));
+
+    coordinator = new ReadableCoordinator(taskC);
+    coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+    coordinatorRequests.update(coordinator);
+    assertTrue(coordinatorRequests.commitRequests().contains(taskC));
+    assertFalse(coordinatorRequests.commitRequests().contains(taskB));
+    assertTrue(coordinatorRequests.commitRequests().size() == 2);
+
+    coordinator.commit(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+    coordinatorRequests.update(coordinator);
+    assertTrue(coordinatorRequests.commitRequests().contains(taskB));
+    assertTrue(coordinatorRequests.commitRequests().size() == 3);
+  }
+
+  @Test
+  public void testUpdateShutdownOnConsensus() {
+    ReadableCoordinator coordinator = new ReadableCoordinator(taskA);
+    coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+    coordinatorRequests.update(coordinator);
+    assertFalse(coordinatorRequests.shouldShutdownNow());
+
+    coordinator = new ReadableCoordinator(taskB);
+    coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+    coordinatorRequests.update(coordinator);
+    assertFalse(coordinatorRequests.shouldShutdownNow());
+
+    coordinator = new ReadableCoordinator(taskC);
+    coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+    coordinatorRequests.update(coordinator);
+    assertTrue(coordinatorRequests.shouldShutdownNow());
+  }
+
+  @Test
+  public void testUpdateShutdownNow() {
+    ReadableCoordinator coordinator = new ReadableCoordinator(taskA);
+    coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+    coordinatorRequests.update(coordinator);
+    assertTrue(coordinatorRequests.shouldShutdownNow());
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackImpl.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackImpl.java b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackImpl.java
new file mode 100644
index 0000000..f1dbf35
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackImpl.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+
+
+public class TestTaskCallbackImpl {
+
+  TaskCallbackListener listener = null;
+  AtomicInteger completeCount;
+  AtomicInteger failureCount;
+  TaskCallback callback = null;
+  Throwable throwable = null;
+
+  @Before
+  public void setup() {
+    completeCount = new AtomicInteger(0);
+    failureCount = new AtomicInteger(0);
+    throwable = null;
+
+    listener = new TaskCallbackListener() {
+
+      @Override
+      public void onComplete(TaskCallback callback) {
+        completeCount.incrementAndGet();
+      }
+
+      @Override
+      public void onFailure(TaskCallback callback, Throwable t) {
+        throwable = t;
+        failureCount.incrementAndGet();
+      }
+    };
+
+    callback = new TaskCallbackImpl(listener, null, mock(IncomingMessageEnvelope.class), null, 0);
+  }
+
+  @Test
+  public void testComplete() {
+    callback.complete();
+    assertEquals(1L, completeCount.get());
+    assertEquals(0L, failureCount.get());
+  }
+
+  @Test
+  public void testFailure() {
+    callback.failure(new Exception("dummy exception"));
+    assertEquals(0L, completeCount.get());
+    assertEquals(1L, failureCount.get());
+  }
+
+  @Test
+  public void testCallbackMultipleComplete() {
+    callback.complete();
+    assertEquals(1L, completeCount.get());
+
+    callback.complete();
+    assertEquals(1L, failureCount.get());
+    assertTrue(throwable instanceof IllegalStateException);
+  }
+
+  @Test
+  public void testCallbackFailureAfterComplete() {
+    callback.complete();
+    assertEquals(1L, completeCount.get());
+
+    callback.failure(new Exception("dummy exception"));
+    assertEquals(1L, failureCount.get());
+    assertTrue(throwable instanceof IllegalStateException);
+  }
+
+
+  @Test
+  public void testMultithreadedCallbacks() throws Exception {
+    final CyclicBarrier barrier = new CyclicBarrier(2);
+    ExecutorService executor = Executors.newFixedThreadPool(2);
+
+    for (int i = 0; i < 2; i++) {
+      executor.submit(new Runnable() {
+        @Override
+        public void run() {
+          try {
+            barrier.await();
+            callback.complete();
+          } catch (Exception e) {
+            e.printStackTrace();
+          }
+        }
+      });
+    }
+    executor.awaitTermination(1, TimeUnit.SECONDS);
+    assertEquals(1L, completeCount.get());
+    assertEquals(1L, failureCount.get());
+    assertTrue(throwable instanceof IllegalStateException);
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java
new file mode 100644
index 0000000..d7110f3
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import org.apache.samza.Partition;
+import org.apache.samza.container.TaskInstanceMetrics;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStreamPartition;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+
+public class TestTaskCallbackManager {
+  TaskCallbackManager callbackManager = null;
+  TaskCallbackListener listener = null;
+
+  @Before
+  public void setup() {
+    TaskInstanceMetrics metrics = new TaskInstanceMetrics("Partition 0", new MetricsRegistryMap());
+    listener = new TaskCallbackListener() {
+      @Override
+      public void onComplete(TaskCallback callback) {
+      }
+      @Override
+      public void onFailure(TaskCallback callback, Throwable t) {
+      }
+    };
+    callbackManager = new TaskCallbackManager(listener, metrics, null, -1);
+
+  }
+
+  @Test
+  public void testCreateCallback() {
+    TaskCallbackImpl callback = callbackManager.createCallback(new TaskName("Partition 0"), null, null);
+    assertTrue(callback.matchSeqNum(0));
+
+    callback = callbackManager.createCallback(new TaskName("Partition 0"), null, null);
+    assertTrue(callback.matchSeqNum(1));
+  }
+
+  @Test
+  public void testUpdateCallbackInOrder() {
+    TaskName taskName = new TaskName("Partition 0");
+    SystemStreamPartition ssp = new SystemStreamPartition("kafka", "topic", new Partition(0));
+    ReadableCoordinator coordinator = new ReadableCoordinator(taskName);
+
+    IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp, "0", null, null);
+    TaskCallbackImpl callback0 = new TaskCallbackImpl(listener, taskName, envelope0, coordinator, 0);
+    TaskCallbackImpl callbackToCommit = callbackManager.updateCallback(callback0, true);
+    assertTrue(callbackToCommit.matchSeqNum(0));
+    assertEquals(ssp, callbackToCommit.envelope.getSystemStreamPartition());
+    assertEquals("0", callbackToCommit.envelope.getOffset());
+
+    IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp, "1", null, null);
+    TaskCallbackImpl callback1 = new TaskCallbackImpl(listener, taskName, envelope1, coordinator, 1);
+    callbackToCommit = callbackManager.updateCallback(callback1, true);
+    assertTrue(callbackToCommit.matchSeqNum(1));
+    assertEquals(ssp, callbackToCommit.envelope.getSystemStreamPartition());
+    assertEquals("1", callbackToCommit.envelope.getOffset());
+  }
+
+  @Test
+  public void testUpdateCallbackOutofOrder() {
+    TaskName taskName = new TaskName("Partition 0");
+    SystemStreamPartition ssp = new SystemStreamPartition("kafka", "topic", new Partition(0));
+    ReadableCoordinator coordinator = new ReadableCoordinator(taskName);
+
+    // simulate out of order
+    IncomingMessageEnvelope envelope2 = new IncomingMessageEnvelope(ssp, "2", null, null);
+    TaskCallbackImpl callback2 = new TaskCallbackImpl(listener, taskName, envelope2, coordinator, 2);
+    TaskCallbackImpl callbackToCommit = callbackManager.updateCallback(callback2, true);
+    assertNull(callbackToCommit);
+
+    IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp, "1", null, null);
+    TaskCallbackImpl callback1 = new TaskCallbackImpl(listener, taskName, envelope1, coordinator, 1);
+    callbackToCommit = callbackManager.updateCallback(callback1, true);
+    assertNull(callbackToCommit);
+
+    IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp, "0", null, null);
+    TaskCallbackImpl callback0 = new TaskCallbackImpl(listener, taskName, envelope0, coordinator, 0);
+    callbackToCommit = callbackManager.updateCallback(callback0, true);
+    assertTrue(callbackToCommit.matchSeqNum(2));
+    assertEquals(ssp, callbackToCommit.envelope.getSystemStreamPartition());
+    assertEquals("2", callbackToCommit.envelope.getOffset());
+  }
+
+  @Test
+  public void testUpdateCallbackWithCoordinatorRequests() {
+    TaskName taskName = new TaskName("Partition 0");
+    SystemStreamPartition ssp = new SystemStreamPartition("kafka", "topic", new Partition(0));
+
+
+    // simulate out of order
+    IncomingMessageEnvelope envelope2 = new IncomingMessageEnvelope(ssp, "2", null, null);
+    ReadableCoordinator coordinator2 = new ReadableCoordinator(taskName);
+    coordinator2.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+    TaskCallbackImpl callback2 = new TaskCallbackImpl(listener, taskName, envelope2, coordinator2, 2);
+    TaskCallbackImpl callbackToCommit = callbackManager.updateCallback(callback2, true);
+    assertNull(callbackToCommit);
+
+    IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp, "1", null, null);
+    ReadableCoordinator coordinator1 = new ReadableCoordinator(taskName);
+    coordinator1.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+    TaskCallbackImpl callback1 = new TaskCallbackImpl(listener, taskName, envelope1, coordinator1, 1);
+    callbackToCommit = callbackManager.updateCallback(callback1, true);
+    assertNull(callbackToCommit);
+
+    IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp, "0", null, null);
+    ReadableCoordinator coordinator = new ReadableCoordinator(taskName);
+    TaskCallbackImpl callback0 = new TaskCallbackImpl(listener, taskName, envelope0, coordinator, 0);
+    callbackToCommit = callbackManager.updateCallback(callback0, true);
+    assertTrue(callbackToCommit.matchSeqNum(1));
+    assertEquals(ssp, callbackToCommit.envelope.getSystemStreamPartition());
+    assertEquals("1", callbackToCommit.envelope.getOffset());
+    assertTrue(callbackToCommit.coordinator.requestedShutdownNow());
+  }
+
+}
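
The three tests above pin down the callback bookkeeping for the new multithreaded run loop: a completed callback only becomes committable once every callback with a lower sequence number has also completed, and updateCallback then hands back the highest contiguous one (offset "2" in the out-of-order case, nothing before that point). A minimal, self-contained sketch of that rule, using illustrative stand-in names rather than the Samza TaskCallbackImpl/TaskCallbackManager classes:

import scala.collection.mutable

// Illustrative stand-in for the commit bookkeeping exercised above, not the
// Samza implementation: completions may arrive in any order, but only the
// highest *contiguous* sequence number is safe to commit.
class ContiguousCompletionTracker[C](seqOf: C => Long) {
  private val completed = mutable.Map.empty[Long, C]
  private var nextSeq = 0L

  /** Record a completion; return the newest callback that is safe to commit, if any. */
  def update(callback: C): Option[C] = synchronized {
    completed.put(seqOf(callback), callback)
    var committable: Option[C] = None
    while (completed.contains(nextSeq)) {
      committable = completed.remove(nextSeq)
      nextSeq += 1
    }
    committable
  }
}

object ContiguousCompletionExample extends App {
  case class Cb(seq: Long, offset: String)
  val tracker = new ContiguousCompletionTracker[Cb](_.seq)

  // Out-of-order completion, mirroring testUpdateCallbackOutofOrder:
  // nothing is committable until seq 0 lands, then offset "2" is returned.
  assert(tracker.update(Cb(2, "2")).isEmpty)
  assert(tracker.update(Cb(1, "1")).isEmpty)
  assert(tracker.update(Cb(0, "0")) == Some(Cb(2, "2")))
}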

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala b/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
index e280daa..aa1a8d6 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
@@ -20,22 +20,26 @@
 package org.apache.samza.container
 
 
-import org.apache.samza.metrics.{Timer, SlidingTimeWindowReservoir, MetricsRegistryMap}
+import org.apache.samza.Partition
+import org.apache.samza.metrics.MetricsRegistryMap
+import org.apache.samza.metrics.SlidingTimeWindowReservoir
+import org.apache.samza.metrics.Timer
+import org.apache.samza.system.IncomingMessageEnvelope
+import org.apache.samza.system.SystemConsumers
+import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.task.TaskCoordinator.RequestScope
+import org.apache.samza.task.ReadableCoordinator
+import org.apache.samza.task.StreamTask
 import org.apache.samza.util.Clock
-import org.junit.Test
 import org.junit.Assert._
+import org.junit.Test
 import org.mockito.Matchers
 import org.mockito.Mockito._
-import org.mockito.internal.util.reflection.Whitebox
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
 import org.scalatest.junit.AssertionsForJUnit
-import org.scalatest.{Matchers => ScalaTestMatchers}
 import org.scalatest.mock.MockitoSugar
-import org.apache.samza.Partition
-import org.apache.samza.system.{ IncomingMessageEnvelope, SystemConsumers, SystemStreamPartition }
-import org.apache.samza.task.ReadableCoordinator
-import org.apache.samza.task.TaskCoordinator.RequestScope
+import org.scalatest.{Matchers => ScalaTestMatchers}
 
 class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMatchers {
   class StopRunLoop extends RuntimeException
@@ -49,12 +53,12 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
   val envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0")
   val envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1")
 
-  def getMockTaskInstances: Map[TaskName, TaskInstance] = {
-    val ti0 = mock[TaskInstance]
+  def getMockTaskInstances: Map[TaskName, TaskInstance[StreamTask]] = {
+    val ti0 = mock[TaskInstance[StreamTask]]
     when(ti0.systemStreamPartitions).thenReturn(Set(ssp0))
     when(ti0.taskName).thenReturn(taskName0)
 
-    val ti1 = mock[TaskInstance]
+    val ti1 = mock[TaskInstance[StreamTask]]
     when(ti1.systemStreamPartitions).thenReturn(Set(ssp1))
     when(ti1.taskName).thenReturn(taskName1)
 
@@ -67,10 +71,10 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
     val consumers = mock[SystemConsumers]
     val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics)
 
-    when(consumers.choose).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop)
+    when(consumers.choose()).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop)
     intercept[StopRunLoop] { runLoop.run }
-    verify(taskInstances(taskName0)).process(Matchers.eq(envelope0), anyObject)
-    verify(taskInstances(taskName1)).process(Matchers.eq(envelope1), anyObject)
+    verify(taskInstances(taskName0)).process(Matchers.eq(envelope0), anyObject, anyObject)
+    verify(taskInstances(taskName1)).process(Matchers.eq(envelope1), anyObject, anyObject)
     runLoop.metrics.envelopes.getCount should equal(2L)
     runLoop.metrics.nullEnvelopes.getCount should equal(0L)
   }
@@ -80,7 +84,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
     val consumers = mock[SystemConsumers]
     val map = getMockTaskInstances - taskName1 // This test only needs p0
     val runLoop = new RunLoop(map, consumers, new SamzaContainerMetrics)
-    when(consumers.choose).thenReturn(null).thenReturn(null).thenThrow(new StopRunLoop)
+    when(consumers.choose()).thenReturn(null).thenReturn(null).thenThrow(new StopRunLoop)
     intercept[StopRunLoop] { runLoop.run }
     runLoop.metrics.envelopes.getCount should equal(0L)
     runLoop.metrics.nullEnvelopes.getCount should equal(2L)
@@ -90,7 +94,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
   def testWindowAndCommitAreCalledRegularly {
     var now = 1400000000000L
     val consumers = mock[SystemConsumers]
-    when(consumers.choose).thenReturn(envelope0)
+    when(consumers.choose()).thenReturn(envelope0)
 
     val runLoop = new RunLoop(
       taskInstances = getMockTaskInstances,
@@ -118,7 +122,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
     val consumers = mock[SystemConsumers]
     val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1)
 
-    when(consumers.choose).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop)
+    when(consumers.choose()).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop)
     stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.commit(RequestScope.CURRENT_TASK))
 
     intercept[StopRunLoop] { runLoop.run }
@@ -132,7 +136,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
     val consumers = mock[SystemConsumers]
     val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1)
 
-    when(consumers.choose).thenReturn(envelope0).thenThrow(new StopRunLoop)
+    when(consumers.choose()).thenReturn(envelope0).thenThrow(new StopRunLoop)
     stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.commit(RequestScope.ALL_TASKS_IN_CONTAINER))
 
     intercept[StopRunLoop] { runLoop.run }
@@ -146,13 +150,13 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
     val consumers = mock[SystemConsumers]
     val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1)
 
-    when(consumers.choose).thenReturn(envelope0).thenReturn(envelope0).thenReturn(envelope1)
+    when(consumers.choose()).thenReturn(envelope0).thenReturn(envelope0).thenReturn(envelope1)
     stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.shutdown(RequestScope.CURRENT_TASK))
     stubProcess(taskInstances(taskName1), (envelope, coordinator) => coordinator.shutdown(RequestScope.CURRENT_TASK))
 
     runLoop.run
-    verify(taskInstances(taskName0), times(2)).process(Matchers.eq(envelope0), anyObject)
-    verify(taskInstances(taskName1), times(1)).process(Matchers.eq(envelope1), anyObject)
+    verify(taskInstances(taskName0), times(2)).process(Matchers.eq(envelope0), anyObject, anyObject)
+    verify(taskInstances(taskName1), times(1)).process(Matchers.eq(envelope1), anyObject, anyObject)
   }
 
   @Test
@@ -161,19 +165,19 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
     val consumers = mock[SystemConsumers]
     val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1)
 
-    when(consumers.choose).thenReturn(envelope0).thenReturn(envelope1)
+    when(consumers.choose()).thenReturn(envelope0).thenReturn(envelope1)
     stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.shutdown(RequestScope.ALL_TASKS_IN_CONTAINER))
 
     runLoop.run
-    verify(taskInstances(taskName0), times(1)).process(anyObject, anyObject)
-    verify(taskInstances(taskName1), times(0)).process(anyObject, anyObject)
+    verify(taskInstances(taskName0), times(1)).process(anyObject, anyObject, anyObject)
+    verify(taskInstances(taskName1), times(0)).process(anyObject, anyObject, anyObject)
   }
 
   def anyObject[T] = Matchers.anyObject.asInstanceOf[T]
 
   // Stub out TaskInstance.process. Mockito really doesn't make this easy. :(
-  def stubProcess(taskInstance: TaskInstance, process: (IncomingMessageEnvelope, ReadableCoordinator) => Unit) {
-    when(taskInstance.process(anyObject, anyObject)).thenAnswer(new Answer[Unit]() {
+  def stubProcess(taskInstance: TaskInstance[StreamTask], process: (IncomingMessageEnvelope, ReadableCoordinator) => Unit) {
+    when(taskInstance.process(anyObject, anyObject, anyObject)).thenAnswer(new Answer[Unit]() {
       override def answer(invocation: InvocationOnMock) {
         val envelope = invocation.getArguments()(0).asInstanceOf[IncomingMessageEnvelope]
         val coordinator = invocation.getArguments()(1).asInstanceOf[ReadableCoordinator]
@@ -186,7 +190,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
   def testUpdateTimerCorrectly {
     var now = 0L
     val consumers = mock[SystemConsumers]
-    when(consumers.choose).thenReturn(envelope0)
+    when(consumers.choose()).thenReturn(envelope0)
     val clock = new Clock {
       var c = 0L
       def currentTimeMillis: Long = {
@@ -263,9 +267,9 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
 
   @Test
   def testGetSystemStreamPartitionToTaskInstancesMapping {
-    val ti0 = mock[TaskInstance]
-    val ti1 = mock[TaskInstance]
-    val ti2 = mock[TaskInstance]
+    val ti0 = mock[TaskInstance[StreamTask]]
+    val ti1 = mock[TaskInstance[StreamTask]]
+    val ti2 = mock[TaskInstance[StreamTask]]
     when(ti0.systemStreamPartitions).thenReturn(Set(ssp0))
     when(ti1.systemStreamPartitions).thenReturn(Set(ssp1))
     when(ti2.systemStreamPartitions).thenReturn(Set(ssp1))
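
The TestRunLoop changes above adapt the stubs to the widened TaskInstance: the mocks are now typed TaskInstance[StreamTask], and process is stubbed and verified with a third argument (presumably the completion-callback hook this commit threads through; its exact Samza type is not shown in these hunks, so a plain Runnable stands in below). A hedged sketch of the same Answer-based stubbing against a hypothetical three-argument trait, outside the Samza test harness:

import org.mockito.Matchers
import org.mockito.Mockito.{mock, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer

// Hypothetical stand-in for the widened process(...) signature; not the Samza interface.
trait ThreeArgTask {
  def process(envelope: String, coordinator: String, callback: Runnable): Unit
}

object StubThreeArgProcess extends App {
  def anyObject[T] = Matchers.anyObject.asInstanceOf[T]

  val task = mock(classOf[ThreeArgTask])

  // Same Answer-based pattern as stubProcess above: pull the real arguments
  // back out of the invocation so the test body can react to them.
  when(task.process(anyObject, anyObject, anyObject)).thenAnswer(new Answer[Unit] {
    override def answer(invocation: InvocationOnMock): Unit = {
      val envelope = invocation.getArguments()(0).asInstanceOf[String]
      println("stubbed process saw envelope " + envelope)
    }
  })

  task.process("envelope0", "coordinator0", new Runnable { def run(): Unit = () })
  verify(task).process(Matchers.eq("envelope0"), anyObject, anyObject)
}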

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
index 1358fdd..cff6b96 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
@@ -180,7 +180,7 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
       new SerdeManager)
     val collector = new TaskInstanceCollector(producerMultiplexer)
     val containerContext = new SamzaContainerContext(0, config, Set[TaskName](taskName))
-    val taskInstance: TaskInstance = new TaskInstance(
+    val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
       task,
       taskName,
       config,
@@ -261,7 +261,7 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
       }
     })
 
-    val taskInstance: TaskInstance = new TaskInstance(
+    val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
       task,
       taskName,
       config,
@@ -314,4 +314,4 @@ class MockJobServlet(exceptionLimit: Int, jobModelRef: AtomicReference[JobModel]
       jobModel
     }
   }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
index 5457f0e..3c83529 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
@@ -71,7 +71,7 @@ class TestTaskInstance {
     val taskName = new TaskName("taskName")
     val collector = new TaskInstanceCollector(producerMultiplexer)
     val containerContext = new SamzaContainerContext(0, config, Set[TaskName](taskName))
-    val taskInstance: TaskInstance = new TaskInstance(
+    val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
       task,
       taskName,
       config,
@@ -169,7 +169,7 @@ class TestTaskInstance {
 
     val registry = new MetricsRegistryMap
     val taskMetrics = new TaskInstanceMetrics(registry = registry)
-    val taskInstance: TaskInstance = new TaskInstance(
+    val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
       task,
       taskName,
       config,
@@ -226,7 +226,7 @@ class TestTaskInstance {
 
     val registry = new MetricsRegistryMap
     val taskMetrics = new TaskInstanceMetrics(registry = registry)
-    val taskInstance: TaskInstance = new TaskInstance(
+    val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
       task,
       taskName,
       config,

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala b/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala
index 09da62e..db2249b 100644
--- a/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala
+++ b/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala
@@ -54,14 +54,14 @@ class TestSystemConsumers {
     consumer.setResponseSizes(numEnvelopes)
 
     // Choose to trigger a refresh with data.
-    assertNull(consumers.choose)
+    assertNull(consumers.choose())
     // 2: First on start, second on choose.
     assertEquals(2, consumer.polls)
     assertEquals(2, consumer.lastPoll.size)
     assertTrue(consumer.lastPoll.contains(systemStreamPartition0))
     assertTrue(consumer.lastPoll.contains(systemStreamPartition1))
-    assertEquals(envelope, consumers.choose)
-    assertEquals(envelope, consumers.choose)
+    assertEquals(envelope, consumers.choose())
+    assertEquals(envelope, consumers.choose())
     // We aren't polling because we're getting non-null envelopes.
     assertEquals(2, consumer.polls)
 
@@ -69,7 +69,7 @@ class TestSystemConsumers {
     // messages.
     now = SystemConsumers.DEFAULT_POLL_INTERVAL_MS
 
-    assertEquals(envelope, consumers.choose)
+    assertEquals(envelope, consumers.choose())
 
     // We polled even though there are still 997 messages in the unprocessed
     // message buffer.
@@ -82,11 +82,11 @@ class TestSystemConsumers {
     // Now drain all messages for SSP0. There should be exactly 997 messages,
     // since we have chosen 3 already, and we started with 1000.
     (0 until (numEnvelopes - 3)).foreach { i =>
-      assertEquals(envelope, consumers.choose)
+      assertEquals(envelope, consumers.choose())
     }
 
     // Nothing left. Should trigger a poll here.
-    assertNull(consumers.choose)
+    assertNull(consumers.choose())
     assertEquals(4, consumer.polls)
     assertEquals(2, consumer.lastPoll.size)
 
@@ -117,31 +117,31 @@ class TestSystemConsumers {
     consumer.setResponseSizes(1)
 
     // Choose to trigger a refresh with data.
-    assertNull(consumers.choose)
+    assertNull(consumers.choose())
 
     // Choose should have triggered a second poll, since no messages are available.
     assertEquals(2, consumer.polls)
 
     // Choose a few times. This time there is no data.
-    assertEquals(envelope, consumers.choose)
-    assertNull(consumers.choose)
-    assertNull(consumers.choose)
+    assertEquals(envelope, consumers.choose())
+    assertNull(consumers.choose())
+    assertNull(consumers.choose())
 
     // Return more than one message this time.
     consumer.setResponseSizes(2)
 
     // Choose to trigger a refresh with data.
-    assertNull(consumers.choose)
+    assertNull(consumers.choose())
 
     // Increase clock interval.
     now = SystemConsumers.DEFAULT_POLL_INTERVAL_MS
 
     // We get two messages now.
-    assertEquals(envelope, consumers.choose)
+    assertEquals(envelope, consumers.choose())
     // Should not poll even though clock interval increases past interval threshold.
     assertEquals(2, consumer.polls)
-    assertEquals(envelope, consumers.choose)
-    assertNull(consumers.choose)
+    assertEquals(envelope, consumers.choose())
+    assertNull(consumers.choose())
   }
 
   @Test
@@ -238,7 +238,7 @@ class TestSystemConsumers {
 
     var caughtRightException = false
     try {
-      consumers.choose
+      consumers.choose()
     } catch {
       case e: SystemConsumersException => caughtRightException = true
       case _: Throwable => caughtRightException = false
@@ -256,13 +256,13 @@ class TestSystemConsumers {
 
     var notThrowException = true;
     try {
-      consumers2.choose
+      consumers2.choose()
     } catch {
       case e: Throwable => notThrowException = false
     }
     assertTrue("it should not throw any exception", notThrowException)
 
-    var msgEnvelope = Some(consumers2.choose)
+    var msgEnvelope = Some(consumers2.choose())
     assertTrue("Consumer did not succeed in receiving the second message after Serde exception in choose", msgEnvelope.get != null)
     consumers2.stop
 
@@ -279,7 +279,7 @@ class TestSystemConsumers {
     assertTrue("SystemConsumer start should not throw any Serde exception", notThrowException)
 
     msgEnvelope = null
-    msgEnvelope = Some(consumers2.choose)
+    msgEnvelope = Some(consumers2.choose())
     assertTrue("Consumer did not succeed in receiving the second message after Serde exception in poll", msgEnvelope.get != null)
     consumers2.stop
 

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemProducer.scala
----------------------------------------------------------------------
diff --git a/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemProducer.scala b/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemProducer.scala
index 1f4b5c4..24bc8b5 100644
--- a/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemProducer.scala
+++ b/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemProducer.scala
@@ -36,6 +36,7 @@ class HdfsSystemProducer(
   val clock: () => Long = () => System.currentTimeMillis) extends SystemProducer with Logging with TimerUtils {
   val dfs = FileSystem.get(new Configuration(true))
   val writers: MMap[String, HdfsWriter[_]] = MMap.empty[String, HdfsWriter[_]]
+  private val lock = new Object // synchronization lock for thread-safe access
 
   def start(): Unit = {
     info("entering HdfsSystemProducer.start() call for system: " + systemName + ", client: " + clientId)
@@ -43,52 +44,65 @@ class HdfsSystemProducer(
 
   def stop(): Unit = {
     info("entering HdfsSystemProducer.stop() for system: " + systemName + ", client: " + clientId)
-    writers.values.map { _.close }
-    dfs.close
+
+    lock.synchronized {
+      writers.values.map(_.close)
+      dfs.close
+    }
   }
 
   def register(source: String): Unit = {
     info("entering HdfsSystemProducer.register(" + source + ") " +
       "call for system: " + systemName + ", client: " + clientId)
-    writers += (source -> HdfsWriter.getInstance(dfs, systemName, config))
+
+    lock.synchronized {
+      writers += (source -> HdfsWriter.getInstance(dfs, systemName, config))
+    }
   }
 
   def flush(source: String): Unit = {
     debug("entering HdfsSystemProducer.flush(" + source + ") " +
       "call for system: " + systemName + ", client: " + clientId)
-    try {
-      metrics.flushes.inc
-      updateTimer(metrics.flushMs) { writers.get(source).head.flush }
-      metrics.flushSuccess.inc
-    } catch {
-      case e: Exception => {
-        metrics.flushFailed.inc
-        warn("Exception thrown while client " + clientId + " flushed HDFS out stream, msg: " + e.getMessage)
-        debug("Detailed message from exception thrown by client " + clientId + " in HDFS flush: ", e)
-        writers.get(source).head.close
-        throw e
+
+    metrics.flushes.inc
+    lock.synchronized {
+      try {
+        updateTimer(metrics.flushMs) {
+          writers.get(source).head.flush
+        }
+      } catch {
+        case e: Exception => {
+          metrics.flushFailed.inc
+          warn("Exception thrown while client " + clientId + " flushed HDFS out stream, msg: " + e.getMessage)
+          debug("Detailed message from exception thrown by client " + clientId + " in HDFS flush: ", e)
+          writers.get(source).head.close
+          throw e
+        }
       }
     }
+    metrics.flushSuccess.inc
   }
 
   def send(source: String, ome: OutgoingMessageEnvelope) = {
     debug("entering HdfsSystemProducer.send(source = " + source + ", envelope) " +
       "call for system: " + systemName + ", client: " + clientId)
+
     metrics.sends.inc
-    try {
-      updateTimer(metrics.sendMs) {
-        writers.get(source).head.write(ome)
-      }
-      metrics.sendSuccess.inc
-    } catch {
-      case e: Exception => {
-        metrics.sendFailed.inc
-        warn("Exception thrown while client " + clientId + " wrote to HDFS, msg: " + e.getMessage)
-        debug("Detailed message from exception thrown by client " + clientId + " in HDFS write: ", e)
-        writers.get(source).head.close
-        throw e
+    lock.synchronized {
+      try {
+        updateTimer(metrics.sendMs) {
+          writers.get(source).head.write(ome)
+        }
+      } catch {
+        case e: Exception => {
+          metrics.sendFailed.inc
+          warn("Exception thrown while client " + clientId + " wrote to HDFS, msg: " + e.getMessage)
+          debug("Detailed message from exception thrown by client " + clientId + " in HDFS write: ", e)
+          writers.get(source).head.close
+          throw e
+        }
       }
     }
+    metrics.sendSuccess.inc
   }
-
-}
+}
\ No newline at end of file
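
The HdfsSystemProducer change above makes the producer safe to call from multiple task threads: every access to the shared writers map (and the final dfs.close) now goes through one coarse-grained lock, while the cheap metric increments stay outside the critical section. A minimal sketch of that pattern with stand-in types (Writer and LockingProducer are illustrative, not the Samza or HDFS classes):

import scala.collection.mutable.{Map => MMap}

// Stand-in for a per-source writer; not the Samza HdfsWriter.
trait Writer { def write(msg: String): Unit; def close(): Unit }

// Coarse-grained locking as in the change above: one lock object guards all
// reads and writes of the shared mutable map, so register/send/stop can be
// called concurrently from several task threads without corrupting it.
class LockingProducer(newWriter: String => Writer) {
  private val writers: MMap[String, Writer] = MMap.empty
  private val lock = new Object // single lock for thread-safe access

  def register(source: String): Unit = lock.synchronized {
    writers += (source -> newWriter(source))
  }

  def send(source: String, msg: String): Unit = {
    // metric counters would be bumped here, outside the critical section
    lock.synchronized {
      writers(source).write(msg)
    }
  }

  def stop(): Unit = lock.synchronized {
    writers.values.foreach(_.close())
    writers.clear()
  }
}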

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kafka/src/main/scala/org/apache/samza/migration/KafkaCheckpointMigration.scala
----------------------------------------------------------------------
diff --git a/samza-kafka/src/main/scala/org/apache/samza/migration/KafkaCheckpointMigration.scala b/samza-kafka/src/main/scala/org/apache/samza/migration/KafkaCheckpointMigration.scala
index 5e8cc65..5d2641a 100644
--- a/samza-kafka/src/main/scala/org/apache/samza/migration/KafkaCheckpointMigration.scala
+++ b/samza-kafka/src/main/scala/org/apache/samza/migration/KafkaCheckpointMigration.scala
@@ -140,6 +140,7 @@ class KafkaCheckpointMigration extends MigrationPlan with Logging {
   def migrationCompletionMark(coordinatorSystemProducer: CoordinatorStreamSystemProducer) = {
     info("Marking completion of migration %s" format migrationKey)
     val message = new SetMigrationMetaMessage(source, migrationKey, migrationVal)
+    coordinatorSystemProducer.register(source)
     coordinatorSystemProducer.start()
     coordinatorSystemProducer.send(message)
     coordinatorSystemProducer.stop()